Skip to content

Commit

Permalink
[Contrib] Add MKL DNN
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon committed Nov 2, 2019
1 parent e4c00a3 commit 84330c3
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 1 deletion.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson")
# Contrib library options
tvm_option(USE_BLAS "The blas library to be linked" none)
tvm_option(USE_MKL_PATH "MKL root path when use MKL blas" none)
tvm_option(USE_MKL_DNN "Build with MKL DNN" OFF)
tvm_option(USE_CUDNN "Build with cuDNN" OFF)
tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
Expand Down
3 changes: 3 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ set(USE_BLAS none)
# set(USE_MKL_PATH <path to venv or site-packages directory>) if using `pip install mkl`
set(USE_MKL_PATH none)

# Whether use MKL DNN library
set(USE_MKL_DNN OFF)

# Whether use OpenMP thread pool, choices: gnu, intel
# Note: "gnu" uses gomp library, "intel" uses iomp5 library
set(USE_OPENMP none)
Expand Down
8 changes: 8 additions & 0 deletions cmake/modules/contrib/BLAS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,11 @@ elseif(USE_BLAS STREQUAL "none")
else()
message(FATAL_ERROR "Invalid option: USE_BLAS=" ${USE_BLAS})
endif()

if(USE_MKL_DNN STREQUAL "ON")
find_library(BLAS_LIBRARY_MKLDNN dnnl)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY_MKLDNN})
list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC})
add_definitions(-DUSE_MKL_DNN=1)
message(STATUS "Use MKL DNN library " ${BLAS_LIBRARY_MKLDNN})
endif()
10 changes: 10 additions & 0 deletions src/runtime/contrib/cblas/cblas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ extern "C" {
#else
#include <cblas.h>
#endif
#if USE_MKL_DNN == 1
#include <dnnl.h>
#endif
}

namespace tvm {
Expand All @@ -40,12 +43,19 @@ using namespace runtime;

inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; }

inline char BooleanToTransposeChar(bool trans) { return trans ? 'T' : 'N'; }

struct CblasSgemmOp {
typedef float TDatatype;
void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
int ldb, float beta, float* C, int ldc) {
#if USE_MKL_DNN == 1
dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B,
ldb, A, lda, beta, C, ldc);
#else
cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
#endif
}
};

Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/x86/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
if "cblas" in target.libs:
C = cblas.matmul(data, weight, False, True)
if bias is not None:
C = tvm.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype),
C = tvm.compute(C.shape, lambda i, j: C[i, j] + bias[j],
tag=tag.BROADCAST)
return C

Expand Down

0 comments on commit 84330c3

Please sign in to comment.