From 7f6c1a95ddbe213e03c1026a533b8c75568e2dd0 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Mon, 10 Aug 2020 19:16:41 -0700 Subject: [PATCH] [Topi,x86] Split MKL from BLAS. (#6182) Make cblas and mkl seperate entities in cmake and topi, allowing users to use both a BLAS library and MKL. In the future, MKL specific functions can be added easily. MKLDNN is also split off from MKL and BLAS for the same reasons. Other improvements: - cblas and mkl strategies are now only applied when they are viable. - compile_engine will log which implementation it has chosen and why. --- CMakeLists.txt | 2 +- cmake/config.cmake | 16 +- cmake/modules/contrib/BLAS.cmake | 56 +++--- python/tvm/contrib/cblas.py | 33 ---- python/tvm/contrib/mkl.py | 126 +++++++++++++ python/tvm/contrib/mkldnn.py | 52 ++++++ python/tvm/relay/backend/compile_engine.py | 27 ++- python/tvm/relay/op/strategy/x86.py | 30 +++- python/tvm/topi/x86/dense.py | 49 ++++- src/runtime/contrib/cblas/cblas.cc | 83 +-------- src/runtime/contrib/cblas/mkl.cc | 198 +++++++++++++++++++++ src/runtime/contrib/cblas/mkldnn.cc | 56 ++++++ tests/python/contrib/test_cblas.py | 79 +++++--- 13 files changed, 621 insertions(+), 186 deletions(-) create mode 100644 python/tvm/contrib/mkl.py create mode 100644 python/tvm/contrib/mkldnn.py create mode 100644 src/runtime/contrib/cblas/mkl.cc create mode 100644 src/runtime/contrib/cblas/mkldnn.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index dbe27f5a7edb5..fa3c2b4b4ad0d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,7 +55,7 @@ tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson") # Contrib library options tvm_option(USE_BLAS "The blas library to be linked" none) -tvm_option(USE_MKL_PATH "MKL root path when use MKL blas" none) +tvm_option(USE_MKL "MKL root path when use MKL blas" OFF) tvm_option(USE_MKLDNN "Build with MKLDNN" OFF) tvm_option(USE_DNNL_CODEGEN "Enable MKLDNN (DNNL) codegen" OFF) tvm_option(USE_CUDNN "Build with cuDNN" OFF) diff --git a/cmake/config.cmake b/cmake/config.cmake index 73c6328e6aaf4..47f20372e906e 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -123,14 +123,18 @@ set(USE_LLVM OFF) #--------------------------------------------- # Contrib libraries #--------------------------------------------- -# Whether use BLAS, choices: openblas, mkl, atlas, apple +# Whether use BLAS, choices: openblas, atlas, apple set(USE_BLAS none) -# /path/to/mkl: mkl root path when use mkl blas library -# set(USE_MKL_PATH /opt/intel/mkl) for UNIX -# set(USE_MKL_PATH ../IntelSWTools/compilers_and_libraries_2018/windows/mkl) for WIN32 -# set(USE_MKL_PATH ) if using `pip install mkl` -set(USE_MKL_PATH none) +# Whether to use MKL +# Possible values: +# - ON: Enable MKL +# - /path/to/mkl: mkl root path +# - OFF: Disable MKL +# set(USE_MKL /opt/intel/mkl) for UNIX +# set(USE_MKL ../IntelSWTools/compilers_and_libraries_2018/windows/mkl) for WIN32 +# set(USE_MKL ) if using `pip install mkl` +set(USE_MKL OFF) # Whether use MKLDNN library, choices: ON, OFF, path to mkldnn library set(USE_MKLDNN OFF) diff --git a/cmake/modules/contrib/BLAS.cmake b/cmake/modules/contrib/BLAS.cmake index c2e1fd65743b8..e8c8e22e33346 100644 --- a/cmake/modules/contrib/BLAS.cmake +++ b/cmake/modules/contrib/BLAS.cmake @@ -15,47 +15,55 @@ # specific language governing permissions and limitations # under the License. -# Plugin rules for cblas -file(GLOB CBLAS_CONTRIB_SRC src/runtime/contrib/cblas/*.cc) - if(USE_BLAS STREQUAL "openblas") find_library(BLAS_LIBRARY openblas) list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY}) - list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC}) - message(STATUS "Use BLAS library " ${BLAS_LIBRARY}) -elseif(USE_BLAS STREQUAL "mkl") - if(NOT IS_DIRECTORY ${USE_MKL_PATH}) - set(USE_MKL_PATH /opt/intel/mkl) - endif() - if(APPLE) - find_library(BLAS_LIBRARY_MKL NAMES mklml mkl_rt HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64) - elseif(UNIX) - find_library(BLAS_LIBRARY_MKL NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64) - elseif(MSVC) - find_library(BLAS_LIBRARY_MKL NAMES mkl_rt HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64_win) - endif() - include_directories(${USE_MKL_PATH}/include) - list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY_MKL}) - list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC}) - add_definitions(-DUSE_MKL_BLAS=1) - message(STATUS "Use BLAS library " ${BLAS_LIBRARY_MKL}) + list(APPEND RUNTIME_SRCS src/runtime/contrib/cblas/cblas.cc) + message(STATUS "Using BLAS library " ${BLAS_LIBRARY}) elseif(USE_BLAS STREQUAL "atlas" OR USE_BLAS STREQUAL "blas") find_library(BLAS_LIBRARY cblas) list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY}) - list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC}) + list(APPEND RUNTIME_SRCS src/runtime/contrib/cblas/cblas.cc) message(STATUS "Use BLAS library " ${BLAS_LIBRARY}) elseif(USE_BLAS STREQUAL "apple") find_library(BLAS_LIBRARY Accelerate) include_directories(${BLAS_LIBRARY}/Versions/Current/Frameworks/vecLib.framework/Versions/Current/Headers/) list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY}) - list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC}) + list(APPEND RUNTIME_SRCS src/runtime/contrib/cblas/cblas.cc) message(STATUS "Use BLAS library " ${BLAS_LIBRARY}) +elseif(USE_BLAS STREQUAL "mkl") + message(DEPRECATION "USE_BLAS=mkl is deprecated. Use USE_MKL=ON instead.") + set(USE_MKL ON) elseif(USE_BLAS STREQUAL "none") # pass else() message(FATAL_ERROR "Invalid option: USE_BLAS=" ${USE_BLAS}) endif() +if(USE_MKL OR USE_MKL_PATH) + if(USE_MKL_PATH) + message(DEPRECATION "USE_MKL_PATH=${USE_MKL_PATH} is deprecated. Use USE_MKL=${USE_MKL_PATH} instead.") + endif() + if(NOT USE_MKL) + set(USE_MKL ${USE_MKL_PATH}) + endif() + if(NOT IS_DIRECTORY ${USE_MKL}) + set(USE_MKL /opt/intel/mkl) + endif() + if(APPLE) + find_library(BLAS_LIBRARY_MKL NAMES mklml mkl_rt HINTS ${USE_MKL}/lib/ ${USE_MKL}/lib/intel64) + elseif(UNIX) + find_library(BLAS_LIBRARY_MKL NAMES mkl_rt mklml_gnu HINTS ${USE_MKL}/lib/ ${USE_MKL}/lib/intel64) + elseif(MSVC) + find_library(BLAS_LIBRARY_MKL NAMES mkl_rt HINTS ${USE_MKL}/lib/ ${USE_MKL}/lib/intel64_win) + endif() + include_directories(${USE_MKL}/include) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY_MKL}) + list(APPEND RUNTIME_SRCS src/runtime/contrib/cblas/mkl.cc) + add_definitions(-DUSE_MKL_BLAS=1) + message(STATUS "Use MKL library " ${BLAS_LIBRARY_MKL}) +endif() + if(IS_DIRECTORY ${USE_MKLDNN}) find_library(MKLDNN_LIBRARY NAMES dnnl HINTS ${USE_MKLDNN}/lib/) if (MKLDNN_LIBRARY STREQUAL "MKLDNN_LIBRARY-NOTFOUND") @@ -63,6 +71,7 @@ if(IS_DIRECTORY ${USE_MKLDNN}) else() include_directories(${USE_MKLDNN}/include) list(APPEND TVM_RUNTIME_LINKER_LIBS ${MKLDNN_LIBRARY}) + list(APPEND RUNTIME_SRCS src/runtime/contrib/cblas/mkldnn.cc) add_definitions(-DUSE_DNNL=1) message(STATUS "Use MKLDNN library " ${MKLDNN_LIBRARY}) endif() @@ -74,6 +83,7 @@ elseif(USE_MKLDNN STREQUAL "ON") list(APPEND TVM_RUNTIME_LINKER_LIBS ${MKLDNN_LIBRARY}) add_definitions(-DUSE_DNNL=1) message(STATUS "Use MKLDNN library " ${MKLDNN_LIBRARY}) + list(APPEND RUNTIME_SRCS src/runtime/contrib/cblas/mkldnn.cc) endif() elseif(USE_MKLDNN STREQUAL "OFF") # pass diff --git a/python/tvm/contrib/cblas.py b/python/tvm/contrib/cblas.py index 68586dfda2d82..e1a4a8a7849b1 100644 --- a/python/tvm/contrib/cblas.py +++ b/python/tvm/contrib/cblas.py @@ -52,39 +52,6 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs): ) -def matmul_u8s8s32(lhs, rhs, transa=False, transb=False, **kwargs): - """Create an extern op that compute matrix mult of A and rhs with CrhsLAS - This function serves as an example on how to call external libraries. - - Parameters - ---------- - lhs: Tensor - The left matrix operand - rhs: Tensor - The right matrix operand - transa: bool - Whether transpose lhs - transb: bool - Whether transpose rhs - - Returns - ------- - C: Tensor - The result tensor. - """ - n = lhs.shape[1] if transa else lhs.shape[0] - m = rhs.shape[0] if transb else rhs.shape[1] - return te.extern( - (n, m), - [lhs, rhs], - lambda ins, outs: tvm.tir.call_packed( - "tvm.contrib.cblas.matmul_u8s8s32", ins[0], ins[1], outs[0], transa, transb - ), - name="C", - **kwargs - ) - - def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs): """Create an extern op that compute batched matrix mult of A and rhs with CBLAS This function serves as an example on how to call external libraries. diff --git a/python/tvm/contrib/mkl.py b/python/tvm/contrib/mkl.py new file mode 100644 index 0000000000000..175db44dd1b7f --- /dev/null +++ b/python/tvm/contrib/mkl.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""External function interface to BLAS libraries.""" +import tvm +from tvm import te + + +def matmul(lhs, rhs, transa=False, transb=False, **kwargs): + """Create an extern op that compute matrix mult of A and rhs with CrhsLAS + This function serves as an example on how to call external libraries. + + Parameters + ---------- + lhs: Tensor + The left matrix operand + rhs: Tensor + The right matrix operand + transa: bool + Whether transpose lhs + transb: bool + Whether transpose rhs + + Returns + ------- + C: Tensor + The result tensor. + """ + n = lhs.shape[1] if transa else lhs.shape[0] + m = rhs.shape[0] if transb else rhs.shape[1] + return te.extern( + (n, m), + [lhs, rhs], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.mkl.matmul", ins[0], ins[1], outs[0], transa, transb + ), + name="C", + **kwargs + ) + + +def matmul_u8s8s32(lhs, rhs, transa=False, transb=False, **kwargs): + """Create an extern op that compute matrix mult of A and rhs with CrhsLAS + This function serves as an example on how to call external libraries. + + Parameters + ---------- + lhs: Tensor + The left matrix operand + rhs: Tensor + The right matrix operand + transa: bool + Whether transpose lhs + transb: bool + Whether transpose rhs + + Returns + ------- + C: Tensor + The result tensor. + """ + n = lhs.shape[1] if transa else lhs.shape[0] + m = rhs.shape[0] if transb else rhs.shape[1] + return te.extern( + (n, m), + [lhs, rhs], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.mkl.matmul_u8s8s32", ins[0], ins[1], outs[0], transa, transb + ), + name="C", + **kwargs + ) + + +def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs): + """Create an extern op that compute batched matrix mult of A and rhs with mkl + This function serves as an example on how to call external libraries. + + Parameters + ---------- + lhs: Tensor + The left matrix operand + rhs: Tensor + The right matrix operand + transa: bool + Whether transpose lhs + transb: bool + Whether transpose rhs + + Returns + ------- + C: Tensor + The result tensor. + """ + b = lhs.shape[0] + n = lhs.shape[2] if transa else lhs.shape[1] + m = rhs.shape[1] if transb else rhs.shape[2] + return te.extern( + (b, n, m), + [lhs, rhs], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.mkl.batch_matmul" + if not iterative + else "tvm.contrib.mkl.batch_matmul_iterative", + ins[0], + ins[1], + outs[0], + transa, + transb, + ), + name="C", + **kwargs + ) diff --git a/python/tvm/contrib/mkldnn.py b/python/tvm/contrib/mkldnn.py new file mode 100644 index 0000000000000..48ba14c13fd3a --- /dev/null +++ b/python/tvm/contrib/mkldnn.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""External function interface to BLAS libraries.""" +import tvm +from tvm import te + + +def matmul(lhs, rhs, transa=False, transb=False, **kwargs): + """Create an extern op that compute matrix mult of A and rhs with CrhsLAS + This function serves as an example on how to call external libraries. + + Parameters + ---------- + lhs: Tensor + The left matrix operand + rhs: Tensor + The right matrix operand + transa: bool + Whether transpose lhs + transb: bool + Whether transpose rhs + + Returns + ------- + C: Tensor + The result tensor. + """ + n = lhs.shape[1] if transa else lhs.shape[0] + m = rhs.shape[0] if transb else rhs.shape[1] + return te.extern( + (n, m), + [lhs, rhs], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.mkl.matmul", ins[0], ins[1], outs[0], transa, transb + ), + name="C", + **kwargs + ) diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 25c75b16c7ef8..f60335a4d44bc 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -181,11 +181,14 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) """ all_impls = get_valid_implementations(op, attrs, inputs, out_type, target) - best_plevel_impl = None - for impl in all_impls: - if best_plevel_impl is None or impl.plevel > best_plevel_impl.plevel: - best_plevel_impl = impl + best_plevel_impl = max(all_impls, key=lambda x: x.plevel) if not use_autotvm: + logger.info( + "Using %s for %s based on highest priority (%d)", + best_plevel_impl.name, + op.name, + best_plevel_impl.plevel, + ) outs = best_plevel_impl.compute(attrs, inputs, out_type) return best_plevel_impl, outs @@ -207,12 +210,21 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) if cfg.is_fallback: # Skip fallback config continue + logger.info( + "Implementation %s for %s has cost %.2e", impl.name, op.name, cfg.cost + ) if best_cfg is None or best_cfg.cost > cfg.cost: best_autotvm_impl = impl best_cfg = cfg autotvm.GLOBAL_SCOPE.silent = False if best_autotvm_impl: # The best autotvm implementation definitely doesn't use fallback config + logger.info( + "Using %s for %s based on lowest cost (%.2e)", + best_autotvm_impl.name, + op.name, + best_cfg.cost, + ) return best_autotvm_impl, outputs[best_autotvm_impl] # Use the implementation with highest plevel if workloads[best_plevel_impl] is not None: @@ -222,6 +234,12 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) if msg not in autotvm.task.DispatchContext.warning_messages: autotvm.task.DispatchContext.warning_messages.add(msg) autotvm_logger.warning(msg) + logger.info( + "Using %s for %s based on highest priority (%s)", + best_plevel_impl.name, + op.name, + best_plevel_impl.plevel, + ) return best_plevel_impl, outputs[best_plevel_impl] @@ -261,7 +279,6 @@ def lower_call(call, inputs, target): if not is_dyn: best_impl, outputs = select_implementation( op, call.attrs, inputs, ret_type, target) - logger.info("Use implementation %s for op %s", best_impl.name, op.name) else: # TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes. # Currently, we just use the implementation with highest plevel diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index eb5b5a5111bdf..d30b6a43984f6 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -262,15 +262,37 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): """dense x86 strategy""" strategy = _op.OpStrategy() m, _ = inputs[0].shape + same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype + dtype = inputs[0].dtype + u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32" strategy.add_implementation(wrap_compute_dense(topi.x86.dense_nopack), wrap_topi_schedule(topi.x86.schedule_dense_nopack), name="dense_nopack.x86", plevel=10) if "cblas" in target.libs: - strategy.add_implementation(wrap_compute_dense(topi.x86.dense_cblas), - wrap_topi_schedule(topi.x86.schedule_dense_cblas), - name="dense_cblas.x86", - plevel=15) + with SpecializedCondition(same_type and dtype in ["float32", "float64"]): + strategy.add_implementation( + wrap_compute_dense(topi.x86.dense_cblas), + wrap_topi_schedule(topi.x86.schedule_dense_cblas), + name="dense_cblas.x86", + plevel=13, + ) + if "mkl" in target.libs: + with SpecializedCondition(same_type and dtype in ["float32", "float64"] or u8s8s32): + strategy.add_implementation( + wrap_compute_dense(topi.x86.dense_mkl), + wrap_topi_schedule(topi.x86.schedule_dense_mkl), + name="dense_mkl.x86", + plevel=14, + ) + if "mkldnn" in target.libs: + with SpecializedCondition(same_type and dtype == "float32"): + strategy.add_implementation( + wrap_compute_dense(topi.x86.dense_mkldnn), + wrap_topi_schedule(topi.x86.schedule_dense_mkldnn), + name="dense_mkldnn.x86", + plevel=15, + ) with SpecializedCondition(m >= 16): # this implementation may not be well-optimized, so use plevel=8 for now. strategy.add_implementation(wrap_compute_dense(topi.x86.dense_pack), diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 500fcfc8c6322..c2e5b554bb863 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -22,6 +22,8 @@ from tvm import autotvm from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas +from tvm.contrib import mkl +from tvm.contrib import mkldnn from .util import get_fp32_len from .. import generic, tag @@ -220,25 +222,56 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s -@autotvm.register_topi_compute("dense_cblas.x86") -def dense_cblas(cfg, data, weight, bias=None, out_dtype=None): - """Compute dense using cblas library""" +def dense_blas_common(cfg, data, weight, bias, out_dtype, lib): + """Compute dense using a BLAS library""" M, K = get_const_tuple(data.shape) N, _ = get_const_tuple(weight.shape) cfg.add_flop(M * K * N * 2) - if data.dtype == 'uint8' and weight.dtype == 'int8' and out_dtype == 'int32': - C = cblas.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype) - elif data.dtype == 'float32': - C = cblas.matmul(data, weight, False, True) + if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype == "int32": + if not hasattr(lib, "matmul_u8s8s32"): + raise NotImplementedError( + f"Dense with {lib.__name__} for {data.dtype} is not supported " + "(matmulu8s8s32 not imlemented)" + ) + C = lib.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype) + elif data.dtype == 'float32' or data.dtype == 'float64': + C = lib.matmul(data, weight, False, True) else: - raise NotImplementedError(f"Dense with cblas for {data.dtype} is not supported") + raise NotImplementedError( + f"Dense with {lib.__name__} for {data.dtype} is not supported" + ) if bias is not None: C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST) return C +@autotvm.register_topi_compute("dense_cblas.x86") +def dense_cblas(cfg, data, weight, bias=None, out_dtype=None): + """Compute dense using a cblas""" + return dense_blas_common(cfg, data, weight, bias, out_dtype, cblas) + @autotvm.register_topi_schedule("dense_cblas.x86") def schedule_dense_cblas(_, outs): """Create schedule for dense_cblas""" return generic.schedule_extern(outs) + +@autotvm.register_topi_compute("dense_mkl.x86") +def dense_mkl(cfg, data, weight, bias=None, out_dtype=None): + """Compute dense using mkl""" + return dense_blas_common(cfg, data, weight, bias, out_dtype, mkl) + +@autotvm.register_topi_schedule("dense_mkl.x86") +def schedule_dense_mkl(_, outs): + """Create schedule for dense_mkl""" + return generic.schedule_extern(outs) + +@autotvm.register_topi_compute("dense_mkldnn.x86") +def dense_mkldnn(cfg, data, weight, bias=None, out_dtype=None): + """Compute dense using mkldnn""" + return dense_blas_common(cfg, data, weight, bias, out_dtype, mkldnn) + +@autotvm.register_topi_schedule("dense_mkldnn.x86") +def schedule_dense_mkldnn(_, outs): + """Create schedule for dense_mkldnn""" + return generic.schedule_extern(outs) diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index e84ee1127fdb2..0bd6e6f936a37 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -24,19 +24,12 @@ #include #include -#include "gemm_common.h" - extern "C" { -#if USE_MKL_BLAS == 1 -#include -#else #include -#endif -#if USE_DNNL == 1 -#include -#endif } +#include "gemm_common.h" + namespace tvm { namespace contrib { @@ -44,48 +37,14 @@ using namespace runtime; inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; } -#if USE_MKL_BLAS == 1 -inline CBLAS_OFFSET StringToOffset(const std::string offset_type) { - if (offset_type != "CblasFixOffset" && offset_type != "CblasColOffset" && - offset_type != "CblasRowOffset") { - LOG(FATAL) << "Unrecognized offset_type " << offset_type; - } - if (offset_type == "CblasFixOffset") { - return CblasFixOffset; - } else if (offset_type == "CblasColOffset") { - return CblasColOffset; - } - return CblasRowOffset; -} -#endif - inline char BooleanToTransposeChar(bool trans) { return trans ? 'T' : 'N'; } -struct CblasGemmU8S8S32Op { - void operator()(bool ta, bool tb, int M, int N, int K, float alpha, const void* A, int lda, - int offset_a, const void* B, int ldb, int offset_b, float beta, int* C, int ldc, - const std::string offset_ctype, int* offset_c) { -#if USE_MKL_BLAS == 1 - cblas_gemm_s8u8s32(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), - StringToOffset(offset_ctype), M, N, K, alpha, A, lda, offset_a, B, ldb, - offset_b, beta, C, ldc, offset_c); -#else - LOG(FATAL) << "Quantized Gemm is supported with MKL Blas only"; -#endif - } -}; - struct CblasSgemmOp { typedef float TDatatype; void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, int ldb, float beta, float* C, int ldc) { -#if USE_DNNL == 1 - dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B, ldb, A, - lda, beta, C, ldc); -#else cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); -#endif } }; @@ -105,25 +64,12 @@ struct CblasSgemmBatchOp { int c_stride, int ldc) { CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); -#if USE_MKL_BLAS == 1 - std::vector A_array(batch_size); - std::vector B_array(batch_size); - std::vector C_array(batch_size); - for (int i = 0; i < batch_size; ++i) { - A_array[i] = A + i * a_stride; - B_array[i] = B + i * b_stride; - C_array[i] = C + i * c_stride; - } - cblas_sgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda, - B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size); -#else for (int i = 0; i < batch_size; ++i) { cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); A += a_stride; B += b_stride; C += c_stride; } -#endif } }; @@ -150,25 +96,12 @@ struct CblasDgemmBatchOp { int c_stride, int ldc) { CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); -#if USE_MKL_BLAS == 1 - std::vector A_array(batch_size); - std::vector B_array(batch_size); - std::vector C_array(batch_size); - for (int i = 0; i < batch_size; ++i) { - A_array[i] = A + i * a_stride; - B_array[i] = B + i * b_stride; - C_array[i] = C + i * c_stride; - } - cblas_dgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda, - B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size); -#else for (int i = 0; i < batch_size; ++i) { cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); A += a_stride; B += b_stride; C += c_stride; } -#endif } }; @@ -199,18 +132,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul").set_body([](TVMArgs args, TVMRet CallGemm(args, ret, CblasDgemmOp()); }); -// integer matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul_u8s8s32") - .set_body([](TVMArgs args, TVMRetValue* ret) { - DLTensor* A = args[0]; - DLTensor* B = args[1]; - DLTensor* C = args[2]; - CHECK(TypeMatch(A->dtype, kDLUInt, 8) && TypeMatch(B->dtype, kDLInt, 8) && - TypeMatch(C->dtype, kDLInt, 32)); - - CallU8S8S32Gemm(args, ret, CblasGemmU8S8S32Op()); - }); - TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* A = args[0]; CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); diff --git a/src/runtime/contrib/cblas/mkl.cc b/src/runtime/contrib/cblas/mkl.cc new file mode 100644 index 0000000000000..fa98a35854ded --- /dev/null +++ b/src/runtime/contrib/cblas/mkl.cc @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file Use external mkl library call. + */ +#include +#include +#include + +extern "C" { +#include +} + +#include "gemm_common.h" + +namespace tvm { +namespace contrib { + +using namespace runtime; + +inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; } + +inline CBLAS_OFFSET StringToOffset(const std::string offset_type) { + if (offset_type != "CblasFixOffset" && offset_type != "CblasColOffset" && + offset_type != "CblasRowOffset") { + LOG(FATAL) << "Unrecognized offset_type " << offset_type; + } + if (offset_type == "CblasFixOffset") { + return CblasFixOffset; + } else if (offset_type == "CblasColOffset") { + return CblasColOffset; + } + return CblasRowOffset; +} + +inline char BooleanToTransposeChar(bool trans) { return trans ? 'T' : 'N'; } + +struct MKLGemmU8S8S32Op { + void operator()(bool ta, bool tb, int M, int N, int K, float alpha, const void* A, int lda, + int offset_a, const void* B, int ldb, int offset_b, float beta, int* C, int ldc, + const std::string offset_ctype, int* offset_c) { + cblas_gemm_s8u8s32(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), + StringToOffset(offset_ctype), M, N, K, alpha, A, lda, offset_a, B, ldb, + offset_b, beta, C, ldc, offset_c); + } +}; + +struct MKLSgemmOp { + typedef float TDatatype; + void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, + int ldb, float beta, float* C, int ldc) { + cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); + } +}; + +struct MKLDgemmOp { + typedef double TDatatype; + void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda, + double* B, int ldb, double beta, double* C, int ldc) { + cblas_dgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); + } +}; + +struct MKLSgemmBatchOp { + typedef float TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, + int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); + std::vector A_array(batch_size); + std::vector B_array(batch_size); + std::vector C_array(batch_size); + for (int i = 0; i < batch_size; ++i) { + A_array[i] = A + i * a_stride; + B_array[i] = B + i * b_stride; + C_array[i] = C + i * c_stride; + } + cblas_sgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda, + B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size); + } +}; + +struct MKLSgemmBatchIterativeOp { + typedef float TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, + int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); + for (int i = 0; i < batch_size; ++i) { + cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + A += a_stride; + B += b_stride; + C += c_stride; + } + } +}; + +struct MKLDgemmBatchOp { + typedef double TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A, + int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); + std::vector A_array(batch_size); + std::vector B_array(batch_size); + std::vector C_array(batch_size); + for (int i = 0; i < batch_size; ++i) { + A_array[i] = A + i * a_stride; + B_array[i] = B + i * b_stride; + C_array[i] = C + i * c_stride; + } + cblas_dgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda, + B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size); + } +}; + +struct MKLDgemmBatchIterativeOp { + typedef double TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A, + int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); + for (int i = 0; i < batch_size; ++i) { + cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + A += a_stride; + B += b_stride; + C += c_stride; + } + } +}; + +// matrix multiplication for row major +TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 32)) + CallGemm(args, ret, MKLSgemmOp()); + else + CallGemm(args, ret, MKLDgemmOp()); +}); + +// integer matrix multiplication for row major +TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul_u8s8s32").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; + CHECK(TypeMatch(A->dtype, kDLUInt, 8) && TypeMatch(B->dtype, kDLInt, 8) && + TypeMatch(C->dtype, kDLInt, 32)); + + CallU8S8S32Gemm(args, ret, MKLGemmU8S8S32Op()); +}); + +TVM_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, MKLSgemmBatchOp()); + } else { + CallBatchGemm(args, ret, MKLDgemmBatchOp()); + } +}); + +TVM_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul_iterative") + .set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, MKLSgemmBatchIterativeOp()); + } else { + CallBatchGemm(args, ret, MKLDgemmBatchIterativeOp()); + } + }); +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/cblas/mkldnn.cc b/src/runtime/contrib/cblas/mkldnn.cc new file mode 100644 index 0000000000000..164f15a8e9ad8 --- /dev/null +++ b/src/runtime/contrib/cblas/mkldnn.cc @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file Use external cblas library call. + */ +#include +#include +#include + +extern "C" { +#include +} + +#include "gemm_common.h" + +namespace tvm { +namespace contrib { + +using namespace runtime; + +inline char BooleanToTransposeChar(bool trans) { return trans ? 'T' : 'N'; } + +struct MKLDNNSgemmOp { + typedef float TDatatype; + void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, + int ldb, float beta, float* C, int ldc) { + dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B, ldb, A, + lda, beta, C, ldc); + } +}; + +// matrix multiplication for row major +TVM_REGISTER_GLOBAL("tvm.contrib.mkldnn.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32)); + CallGemm(args, ret, MKLDNNSgemmOp()); +}); +} // namespace contrib +} // namespace tvm diff --git a/tests/python/contrib/test_cblas.py b/tests/python/contrib/test_cblas.py index 00ddcd3061ace..e1c1c71255365 100644 --- a/tests/python/contrib/test_cblas.py +++ b/tests/python/contrib/test_cblas.py @@ -20,14 +20,16 @@ import numpy as np import tvm.topi.testing from tvm.contrib import cblas +from tvm.contrib import mkl +from tvm.contrib import mkldnn -def verify_matmul_add(m, l, n, transa=False, transb=False, dtype="float32"): +def verify_matmul_add(m, l, n, lib, transa=False, transb=False, dtype="float32"): bias = te.var('bias', dtype=dtype) ashape = (l, n) if transa else (n, l) bshape = (m, l) if transb else (l, m) A = te.placeholder(ashape, name='A', dtype=dtype) B = te.placeholder(bshape, name='B', dtype=dtype) - C = cblas.matmul(A, B, transa, transb) + C = lib.matmul(A, B, transa, transb) D = te.compute(C.shape, lambda i, j: C[i,j] + bias, name="D") s = te.create_schedule(D.op) @@ -42,7 +44,7 @@ def verify(target="llvm"): if not tvm.runtime.enabled(target): print("skip because %s is not enabled..." % target) return - if not tvm.get_global_func("tvm.contrib.cblas.matmul", True): + if not tvm.get_global_func(lib.__name__ + ".matmul", True): print("skip because extern function is not available") return ctx = tvm.cpu(0) @@ -57,17 +59,34 @@ def verify(target="llvm"): verify() def test_matmul_add(): - verify_matmul_add(235, 128, 1024) - verify_matmul_add(235, 128, 1024, True, False) - verify_matmul_add(235, 128, 1024, False, True) - verify_matmul_add(235, 128, 1024, True, True) - verify_matmul_add(1, 16, 4) - verify_matmul_add(1, 16, 3, True, False) - verify_matmul_add(1, 16, 3, False, False) - verify_matmul_add(1, 16, 3, True, True) + verify_matmul_add(235, 128, 1024, cblas) + verify_matmul_add(235, 128, 1024, cblas, True, False) + verify_matmul_add(235, 128, 1024, cblas, False, True) + verify_matmul_add(235, 128, 1024, cblas, True, True) + verify_matmul_add(235, 128, 1024, mkl) + verify_matmul_add(235, 128, 1024, mkl, True, False) + verify_matmul_add(235, 128, 1024, mkl, False, True) + verify_matmul_add(235, 128, 1024, mkl, True, True) + verify_matmul_add(235, 128, 1024, mkldnn) + verify_matmul_add(235, 128, 1024, mkldnn, True, False) + verify_matmul_add(235, 128, 1024, mkldnn, False, True) + verify_matmul_add(235, 128, 1024, mkldnn, True, True) + verify_matmul_add(1, 16, 4, cblas) + verify_matmul_add(1, 16, 3, cblas, True, False) + verify_matmul_add(1, 16, 3, cblas, False, False) + verify_matmul_add(1, 16, 3, cblas, True, True) + verify_matmul_add(1, 16, 4, mkl) + verify_matmul_add(1, 16, 3, mkl, True, False) + verify_matmul_add(1, 16, 3, mkl, False, False) + verify_matmul_add(1, 16, 3, mkl, True, True) + verify_matmul_add(1, 16, 4, mkldnn) + verify_matmul_add(1, 16, 3, mkldnn, True, False) + verify_matmul_add(1, 16, 3, mkldnn, False, False) + verify_matmul_add(1, 16, 3, mkldnn, True, True) def verify_quantized_matmul_add(m, l, n, transa=False, transb=False): - pytest.skip("Quantized dense is supported only for MKL. TVM GPU CI uses openblas") + if not tvm.get_global_func("tvm.contrib.mkl.matmul_u8s8s32", True): + pytest.skip("Quantized dense is supported only for MKL. TVM GPU CI uses openblas") data_dtype = "uint8" kernel_dtype = "int8" out_dtype = "int32" @@ -76,7 +95,7 @@ def verify_quantized_matmul_add(m, l, n, transa=False, transb=False): bshape = (m, l) if transb else (l, m) A = te.placeholder(ashape, name='A', dtype=data_dtype) B = te.placeholder(bshape, name='B', dtype=kernel_dtype) - C = cblas.matmul_u8s8s32(A, B, transa, transb, dtype=out_dtype) + C = mkl.matmul_u8s8s32(A, B, transa, transb, dtype=out_dtype) D = te.compute(C.shape, lambda i, j: C[i,j] + bias, name="D") s = te.create_schedule(D.op) @@ -91,7 +110,7 @@ def verify(target="llvm"): if not tvm.runtime.enabled(target): print("skip because %s is not enabled..." % target) return - if not tvm.get_global_func("tvm.contrib.cblas.matmul_u8s8s32", True): + if not tvm.get_global_func("tvm.contrib.mkl.matmul_u8s8s32", True): print("skip because extern function is not available") return ctx = tvm.cpu(0) @@ -117,7 +136,7 @@ def test_quantized_matmul_add(): verify_quantized_matmul_add(1, 16, 3, False, True) verify_quantized_matmul_add(1, 16, 3, True, True) -def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False, dtype="float32"): +def verify_batch_matmul(batch, m, l, n, lib, transa=False, transb=False, iterative=False, dtype="float32"): ashape = (batch, l, n) if transa else (batch, n, l) bshape = (batch, m, l) if transb else (batch, l, m) A = te.placeholder(ashape, name='A', dtype=dtype) @@ -137,7 +156,7 @@ def verify(target="llvm"): if not tvm.runtime.enabled(target): print("skip because %s is not enabled..." % target) return - if not tvm.get_global_func("tvm.contrib.cblas.matmul", True): + if not tvm.get_global_func(lib.__name__ + ".matmul", True): print("skip because extern function is not available") return ctx = tvm.cpu(0) @@ -151,16 +170,26 @@ def verify(target="llvm"): verify() def test_batch_matmul(): - verify_batch_matmul(16, 235, 128, 1024) - verify_batch_matmul(16, 235, 128, 1024, True, False) - verify_batch_matmul(16, 235, 128, 1024, False, True) - verify_batch_matmul(16, 235, 128, 1024, True, True) - verify_batch_matmul(1, 1, 16, 3) - verify_batch_matmul(1, 1, 16, 3, True, False) - verify_batch_matmul(1, 1, 16, 3, False, False) - verify_batch_matmul(1, 1, 16, 3, True, True) - verify_batch_matmul(1, 1, 16, 3, iterative=True) + verify_batch_matmul(16, 235, 128, 1024, cblas) + verify_batch_matmul(16, 235, 128, 1024, cblas, True, False) + verify_batch_matmul(16, 235, 128, 1024, cblas, False, True) + verify_batch_matmul(16, 235, 128, 1024, cblas, True, True) + verify_batch_matmul(16, 235, 128, 1024, mkl) + verify_batch_matmul(16, 235, 128, 1024, mkl, True, False) + verify_batch_matmul(16, 235, 128, 1024, mkl, False, True) + verify_batch_matmul(16, 235, 128, 1024, mkl, True, True) + verify_batch_matmul(1, 1, 16, 3, cblas) + verify_batch_matmul(1, 1, 16, 3, cblas, True, False) + verify_batch_matmul(1, 1, 16, 3, cblas, False, False) + verify_batch_matmul(1, 1, 16, 3, cblas, True, True) + verify_batch_matmul(1, 1, 16, 3, cblas, iterative=True) + verify_batch_matmul(1, 1, 16, 3, mkl) + verify_batch_matmul(1, 1, 16, 3, mkl, True, False) + verify_batch_matmul(1, 1, 16, 3, mkl, False, False) + verify_batch_matmul(1, 1, 16, 3, mkl, True, True) + verify_batch_matmul(1, 1, 16, 3, mkl, iterative=True) if __name__ == "__main__": test_matmul_add() + test_quantized_matmul_add() test_batch_matmul()