Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCM] add rocm support #1411

Merged
merged 6 commits into from
Apr 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ endif()

project(torchaudio)

# Find the HIP package, set the HIP paths, load the HIP CMake.
if(USE_ROCM)
include(cmake/LoadHIP.cmake)
if(NOT PYTORCH_FOUND_HIP)
set(USE_ROCM OFF)
endif()
endif()

# check and set CMAKE_CXX_STANDARD
string(FIND "${CMAKE_CXX_FLAGS}" "-std=c++" env_cxx_standard)
if(env_cxx_standard GREATER -1)
Expand Down
2 changes: 2 additions & 0 deletions build_tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def _get_build(var, default=False):
_BUILD_SOX = False if platform.system() == 'Windows' else _get_build("BUILD_SOX")
_BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KALDI", True)
_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER")
_USE_ROCM = _get_build("USE_ROCM")


def get_ext_modules():
Expand Down Expand Up @@ -74,6 +75,7 @@ def build_extension(self, ext):
f"-DBUILD_TRANSDUCER:BOOL={'ON' if _BUILD_TRANSDUCER else 'OFF'}",
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
"-DBUILD_LIBTORCHAUDIO:BOOL=OFF",
f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}",
]
build_args = [
'--target', 'install'
Expand Down
234 changes: 234 additions & 0 deletions cmake/LoadHIP.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
set(PYTORCH_FOUND_HIP FALSE)

if(NOT DEFINED ENV{ROCM_PATH})
set(ROCM_PATH /opt/rocm)
else()
set(ROCM_PATH $ENV{ROCM_PATH})
endif()

# HIP_PATH
if(NOT DEFINED ENV{HIP_PATH})
set(HIP_PATH ${ROCM_PATH}/hip)
else()
set(HIP_PATH $ENV{HIP_PATH})
endif()

if(NOT EXISTS ${HIP_PATH})
return()
endif()

# HCC_PATH
if(NOT DEFINED ENV{HCC_PATH})
set(HCC_PATH ${ROCM_PATH}/hcc)
else()
set(HCC_PATH $ENV{HCC_PATH})
endif()

# HSA_PATH
if(NOT DEFINED ENV{HSA_PATH})
set(HSA_PATH ${ROCM_PATH}/hsa)
else()
set(HSA_PATH $ENV{HSA_PATH})
endif()

# ROCBLAS_PATH
if(NOT DEFINED ENV{ROCBLAS_PATH})
set(ROCBLAS_PATH ${ROCM_PATH}/rocblas)
else()
set(ROCBLAS_PATH $ENV{ROCBLAS_PATH})
endif()

# ROCFFT_PATH
if(NOT DEFINED ENV{ROCFFT_PATH})
set(ROCFFT_PATH ${ROCM_PATH}/rocfft)
else()
set(ROCFFT_PATH $ENV{ROCFFT_PATH})
endif()

# HIPFFT_PATH
if(NOT DEFINED ENV{HIPFFT_PATH})
set(HIPFFT_PATH ${ROCM_PATH}/hipfft)
else()
set(HIPFFT_PATH $ENV{HIPFFT_PATH})
endif()

# HIPSPARSE_PATH
if(NOT DEFINED ENV{HIPSPARSE_PATH})
set(HIPSPARSE_PATH ${ROCM_PATH}/hipsparse)
else()
set(HIPSPARSE_PATH $ENV{HIPSPARSE_PATH})
endif()

# THRUST_PATH
if(DEFINED ENV{THRUST_PATH})
set(THRUST_PATH $ENV{THRUST_PATH})
else()
set(THRUST_PATH ${ROCM_PATH}/include)
endif()

# HIPRAND_PATH
if(NOT DEFINED ENV{HIPRAND_PATH})
set(HIPRAND_PATH ${ROCM_PATH}/hiprand)
else()
set(HIPRAND_PATH $ENV{HIPRAND_PATH})
endif()

# ROCRAND_PATH
if(NOT DEFINED ENV{ROCRAND_PATH})
set(ROCRAND_PATH ${ROCM_PATH}/rocrand)
else()
set(ROCRAND_PATH $ENV{ROCRAND_PATH})
endif()

# MIOPEN_PATH
if(NOT DEFINED ENV{MIOPEN_PATH})
set(MIOPEN_PATH ${ROCM_PATH}/miopen)
else()
set(MIOPEN_PATH $ENV{MIOPEN_PATH})
endif()

# RCCL_PATH
if(NOT DEFINED ENV{RCCL_PATH})
set(RCCL_PATH ${ROCM_PATH}/rccl)
else()
set(RCCL_PATH $ENV{RCCL_PATH})
endif()

# ROCPRIM_PATH
if(NOT DEFINED ENV{ROCPRIM_PATH})
set(ROCPRIM_PATH ${ROCM_PATH}/rocprim)
else()
set(ROCPRIM_PATH $ENV{ROCPRIM_PATH})
endif()

# HIPCUB_PATH
if(NOT DEFINED ENV{HIPCUB_PATH})
set(HIPCUB_PATH ${ROCM_PATH}/hipcub)
else()
set(HIPCUB_PATH $ENV{HIPCUB_PATH})
endif()

# ROCTHRUST_PATH
if(NOT DEFINED ENV{ROCTHRUST_PATH})
set(ROCTHRUST_PATH ${ROCM_PATH}/rocthrust)
else()
set(ROCTHRUST_PATH $ENV{ROCTHRUST_PATH})
endif()

# ROCTRACER_PATH
if(NOT DEFINED ENV{ROCTRACER_PATH})
set(ROCTRACER_PATH ${ROCM_PATH}/roctracer)
else()
set(ROCTRACER_PATH $ENV{ROCTRACER_PATH})
endif()

if(NOT DEFINED ENV{PYTORCH_ROCM_ARCH})
set(PYTORCH_ROCM_ARCH gfx803;gfx900;gfx906;gfx908)
else()
set(PYTORCH_ROCM_ARCH $ENV{PYTORCH_ROCM_ARCH})
endif()

# Add HIP to the CMAKE Module Path
set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH})

# Disable Asserts In Code (Can't use asserts on HIP stack.)
add_definitions(-DNDEBUG)

macro(find_package_and_print_version PACKAGE_NAME)
find_package("${PACKAGE_NAME}" ${ARGN})
message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
endmacro()

# Find the HIP Package
find_package_and_print_version(HIP 1.0)

if(HIP_FOUND)
set(PYTORCH_FOUND_HIP TRUE)

# Find ROCM version for checks
file(READ "${ROCM_PATH}/.info/version-dev" ROCM_VERSION_DEV_RAW)
string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW})
if(ROCM_VERSION_DEV_MATCH)
set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1})
set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2})
set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3})
set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
endif()
message("\n***** ROCm version from ${ROCM_PATH}/.info/version-dev ****\n")
message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}")
message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}")
message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}")

message("\n***** Library versions from dpkg *****\n")
execute_process(COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep hip_base COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk "{print $2 \" VERSION: \" $3}")

message("\n***** Library versions from cmake find_package *****\n")

set(CMAKE_HCC_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
set(CMAKE_HCC_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.###

set(hip_DIR ${HIP_PATH}/lib/cmake/hip)
set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64)
set(AMDDeviceLibs_DIR ${ROCM_PATH}/lib/cmake/AMDDeviceLibs)
set(amd_comgr_DIR ${ROCM_PATH}/lib/cmake/amd_comgr)
set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand)
set(hiprand_DIR ${HIPRAND_PATH}/lib/cmake/hiprand)
set(rocblas_DIR ${ROCBLAS_PATH}/lib/cmake/rocblas)
set(miopen_DIR ${MIOPEN_PATH}/lib/cmake/miopen)
set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft)
set(hipfft_DIR ${HIPFFT_PATH}/lib/cmake/hipfft)
set(hipsparse_DIR ${HIPSPARSE_PATH}/lib/cmake/hipsparse)
set(rccl_DIR ${RCCL_PATH}/lib/cmake/rccl)
set(rocprim_DIR ${ROCPRIM_PATH}/lib/cmake/rocprim)
set(hipcub_DIR ${HIPCUB_PATH}/lib/cmake/hipcub)
set(rocthrust_DIR ${ROCTHRUST_PATH}/lib/cmake/rocthrust)

find_package_and_print_version(hip REQUIRED)
find_package_and_print_version(hsa-runtime64 REQUIRED)
find_package_and_print_version(amd_comgr REQUIRED)
find_package_and_print_version(rocrand REQUIRED)
find_package_and_print_version(hiprand REQUIRED)
find_package_and_print_version(rocblas REQUIRED)
find_package_and_print_version(miopen REQUIRED)
find_package_and_print_version(rocfft REQUIRED)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
find_package_and_print_version(hipfft REQUIRED)
endif()
find_package_and_print_version(hipsparse REQUIRED)
find_package_and_print_version(rccl)
find_package_and_print_version(rocprim REQUIRED)
find_package_and_print_version(hipcub REQUIRED)
find_package_and_print_version(rocthrust REQUIRED)

if(HIP_COMPILER STREQUAL clang)
set(hip_library_name amdhip64)
else()
set(hip_library_name hip_hcc)
endif()
message("HIP library name: ${hip_library_name}")

# TODO: hip_hcc has an interface include flag "-hc" which is only
# recognizable by hcc, but not gcc and clang. Right now in our
# setup, hcc is only used for linking, but it should be used to
# compile the *_hip.cc files as well.
find_library(PYTORCH_HIP_HCC_LIBRARIES ${hip_library_name} HINTS ${HIP_PATH}/lib)
# TODO: miopen_LIBRARIES should return fullpath to the library file,
# however currently it's just the lib name
find_library(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${MIOPEN_PATH}/lib)
# TODO: rccl_LIBRARIES should return fullpath to the library file,
# however currently it's just the lib name
find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib)
# hiprtc is part of HIP
find_library(ROCM_HIPRTC_LIB ${hip_library_name} HINTS ${HIP_PATH}/lib)
# roctx is part of roctracer
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib)
set(roctracer_INCLUDE_DIRS ${ROCTRACER_PATH}/include)
endif()
2 changes: 2 additions & 0 deletions test/torchaudio_unittest/backend/soundfile/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
get_wav_data,
load_wav,
nested_params,
skipIfRocm,
)
from .common import (
fetch_wav_subtype,
Expand Down Expand Up @@ -280,6 +281,7 @@ def test_fileobj_wav(self):
self._test_fileobj('wav')

@skipIfFormatNotSupported("FLAC")
@skipIfRocm
def test_fileobj_flac(self):
"""Saving audio via file-like object works"""
self._test_fileobj('flac')
Expand Down
5 changes: 3 additions & 2 deletions test/torchaudio_unittest/common_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
skipIfNoModule,
skipIfNoKaldi,
skipIfNoSox,
skipIfRocm,
)
from .wav_utils import (
get_wav_data,
Expand All @@ -32,5 +33,5 @@
__all__ = ['get_asset_path', 'get_whitenoise', 'get_sinusoid', 'set_audio_backend',
'TempDirMixin', 'HttpServerMixin', 'TestBaseMixin', 'PytorchTestCase', 'TorchaudioTestCase',
'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoKaldi', 'skipIfNoSox',
'skipIfNoSoxBackend', 'get_wav_data', 'normalize_wav', 'load_wav', 'save_wav', 'load_params',
'nested_params']
'skipIfNoSoxBackend', 'skipIfRocm', 'get_wav_data', 'normalize_wav', 'load_wav', 'save_wav',
'load_params', 'nested_params']
2 changes: 2 additions & 0 deletions test/torchaudio_unittest/common_utils/case_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,5 @@ def skipIfNoModule(module, display_name=None):
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available')
skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason='Kaldi not available')
skipIfRocm = unittest.skipIf(os.getenv('TORCHAUDIO_TEST_WITH_ROCM', '0') == '1',
reason="test doesn't currently work on the ROCm stack")
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import torchaudio.functional as F

from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
skipIfRocm,
)


class Functional(common_utils.TestBaseMixin):
Expand Down Expand Up @@ -34,6 +37,7 @@ def func(tensor):
tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor)

@skipIfRocm
def test_griffinlim(self):
def func(tensor):
n_fft = 400
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import torchaudio.transforms as T

from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
skipIfRocm,
)


class Transforms(common_utils.TestBaseMixin):
Expand All @@ -21,6 +24,7 @@ def test_Spectrogram(self):
tensor = torch.rand((1, 1000))
self._assert_consistency(T.Spectrogram(), tensor)

@skipIfRocm
def test_GriffinLim(self):
tensor = torch.rand((1, 201, 6))
self._assert_consistency(T.GriffinLim(length=1000, rand_init=False), tensor)
Expand Down