diff --git a/cmake/modules/contrib/BLAS.cmake b/cmake/modules/contrib/BLAS.cmake index e1e151d6a9f86..a47f83771d374 100644 --- a/cmake/modules/contrib/BLAS.cmake +++ b/cmake/modules/contrib/BLAS.cmake @@ -27,7 +27,11 @@ elseif(USE_BLAS STREQUAL "mkl") if(NOT IS_DIRECTORY ${USE_MKL_PATH}) set(USE_MKL_PATH /opt/intel/mkl) endif() - find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64) + if(APPLE) + find_library(BLAS_LIBRARY NAMES mklml HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64) + elseif(UNIX) + find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64) + endif() include_directories(${USE_MKL_PATH}/include) list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY}) list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC}) diff --git a/python/tvm/contrib/cblas.py b/python/tvm/contrib/cblas.py index c656fcc2b966f..7c024b7928679 100644 --- a/python/tvm/contrib/cblas.py +++ b/python/tvm/contrib/cblas.py @@ -17,10 +17,10 @@ """External function interface to BLAS libraries.""" from __future__ import absolute_import as _abs -from .. import api as _api -from .. import intrin as _intrin +from .. import api as _api, intrin as _intrin -def matmul(lhs, rhs, transa=False, transb=False): + +def matmul(lhs, rhs, transa=False, transb=False, **kwargs): """Create an extern op that compute matrix mult of A and rhs with CrhsLAS This function serves as an example on how to call external libraries. @@ -44,7 +44,50 @@ def matmul(lhs, rhs, transa=False, transb=False): n = lhs.shape[1] if transa else lhs.shape[0] m = rhs.shape[0] if transb else rhs.shape[1] return _api.extern( - (n, m), [lhs, rhs], + (n, m), + [lhs, rhs], + lambda ins, outs: _intrin.call_packed( + "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb + ), + name="C", + **kwargs + ) + + +def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs): + """Create an extern op that compute batched matrix mult of A and rhs with CBLAS + This function serves as an example on how to call external libraries. + Parameters + ---------- + lhs : Tensor + The left matrix operand + rhs : Tensor + The right matrix operand + transa : bool + Whether transpose lhs + transb : bool + Whether transpose rhs + Returns + ------- + C : Tensor + The result tensor. + """ + b = lhs.shape[0] + n = lhs.shape[2] if transa else lhs.shape[1] + m = rhs.shape[1] if transb else rhs.shape[2] + return _api.extern( + (b, n, m), + [lhs, rhs], lambda ins, outs: _intrin.call_packed( - "tvm.contrib.cblas.matmul", - ins[0], ins[1], outs[0], transa, transb), name="C") + "tvm.contrib.cblas.batch_matmul" + if not iterative + else "tvm.contrib.cblas.batch_matmul_iterative", + ins[0], + ins[1], + outs[0], + transa, + transb, + ), + name="C", + **kwargs + ) diff --git a/src/contrib/cblas/cblas.cc b/src/contrib/cblas/cblas.cc index 4ca043f1bcfee..19d325a77917b 100644 --- a/src/contrib/cblas/cblas.cc +++ b/src/contrib/cblas/cblas.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,12 +21,11 @@ * Copyright (c) 2017 by Contributors * \file Use external cblas library call. */ +#include #include #include -#include #include "gemm_common.h" - extern "C" { #if USE_MKL_BLAS == 1 #include @@ -40,56 +39,132 @@ namespace contrib { using namespace runtime; -inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { - return trans ? CblasTrans : CblasNoTrans; -} +inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; } struct CblasSgemmOp { typedef float TDatatype; - void operator()(bool ta, bool tb, - int M, int N, int K, - float alpha, float* A, int lda, - float* B, int ldb, - float beta, float* C, int ldc) { - cblas_sgemm(CblasColMajor, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - alpha, A, lda, - B, ldb, - beta, C, ldc); + void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, + int ldb, float beta, float* C, int ldc) { + cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); } }; struct CblasDgemmOp { typedef double TDatatype; - void operator()(bool ta, bool tb, - int M, int N, int K, - double alpha, double* A, int lda, - double* B, int ldb, - double beta, double* C, int ldc) { - cblas_dgemm(CblasColMajor, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - alpha, A, lda, - B, ldb, - beta, C, ldc); + void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda, + double* B, int ldb, double beta, double* C, int ldc) { + cblas_dgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); } }; +struct CblasSgemmBatchOp { + typedef float TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, + int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); +#if USE_MKL_BLAS == 1 + std::vector A_array(batch_size); + std::vector B_array(batch_size); + std::vector C_array(batch_size); + for (int i = 0; i < batch_size; ++i) { + A_array[i] = A + i * a_stride; + B_array[i] = B + i * b_stride; + C_array[i] = C + i * c_stride; + } + cblas_sgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda, + B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size); +#else + for (int i = 0; i < batch_size; ++i) { + cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + A += a_stride; + B += b_stride; + C += c_stride; + } +#endif + } +}; + +struct CblasSgemmBatchIterativeOp { + typedef float TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, + int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); + for (int i = 0; i < batch_size; ++i) { + cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + A += a_stride; + B += b_stride; + C += c_stride; + } + } +}; + +struct CblasDgemmBatchOp { + typedef double TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A, + int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); +#if USE_MKL_BLAS == 1 + std::vector A_array(batch_size); + std::vector B_array(batch_size); + std::vector C_array(batch_size); + for (int i = 0; i < batch_size; ++i) { + A_array[i] = A + i * a_stride; + B_array[i] = B + i * b_stride; + C_array[i] = C + i * c_stride; + } + cblas_dgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda, + B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size); +#else + for (int i = 0; i < batch_size; ++i) { + cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + A += a_stride; + B += b_stride; + C += c_stride; + } +#endif + } +}; // matrix multiplication for row major TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor* A = args[0]; - CHECK(TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); +.set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 32)) + CallGemm(args, ret, CblasSgemmOp()); + else + CallGemm(args, ret, CblasDgemmOp()); +}); - if (TypeMatch(A->dtype, kDLFloat, 32)) - CallGemm(args, ret, CblasSgemmOp()); - else - CallGemm(args, ret, CblasDgemmOp()); - }); +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") +.set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, CblasSgemmBatchOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchOp()); + } +}); + +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative") +.set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); + } else { + LOG(FATAL) << "Unhandled type"; + } +}); } // namespace contrib } // namespace tvm diff --git a/src/contrib/cblas/gemm_common.h b/src/contrib/cblas/gemm_common.h index fe38b2a675131..103b4ae236c1f 100644 --- a/src/contrib/cblas/gemm_common.h +++ b/src/contrib/cblas/gemm_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,16 +22,17 @@ * \file tvm/contrib/gemm.h * \brief Shared implementation of gemm */ -#ifndef TVM_CONTRIB_CBLAS_GEMM_COMMON_H_ -#define TVM_CONTRIB_CBLAS_GEMM_COMMON_H_ +#pragma once + #include +#include +#include namespace tvm { namespace contrib { using namespace runtime; - -inline int ColumnStride(DLTensor* tensor) { +inline int ColumnStride(DLTensor *tensor) { // If the tensor itself is transposed then it will have strides // backward from what we expect. Regardless, the max of the strides // (the other stride is 1) is the column stride. @@ -42,8 +43,7 @@ inline int ColumnStride(DLTensor* tensor) { } } - -inline int ElementStride(DLTensor* tensor) { +inline int ElementStride(DLTensor *tensor) { if (tensor->strides) { return std::min(tensor->strides[0], tensor->strides[1]); } else { @@ -51,29 +51,26 @@ inline int ElementStride(DLTensor* tensor) { } } - // Reversed strides indicates an in-place transpose operation. -inline bool IsInPlaceTransposed(DLTensor* tensor) { +inline bool IsInPlaceTransposed(DLTensor *tensor) { return tensor->strides && (tensor->strides[1] > tensor->strides[0]); } - -inline int RowCount(DLTensor* tensor, bool trans) { +inline int RowCount(DLTensor *tensor, bool trans) { return tensor->shape[trans ? 1 : 0]; } - -inline int ColumnCount(DLTensor* tensor, bool trans) { +inline int ColumnCount(DLTensor *tensor, bool trans) { return tensor->shape[trans ? 0 : 1]; } // Call a column major blas. Note that data is stored in tvm as row // major, so this we switch the arguments. -template +template inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) { - DLTensor* A = args[0]; - DLTensor* B = args[1]; - DLTensor* C = args[2]; + DLTensor *A = args[0]; + DLTensor *B = args[1]; + DLTensor *C = args[2]; bool transa = args[3]; bool transb = args[4]; int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8; @@ -96,25 +93,88 @@ inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) { CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; - op(transb, - transa, - ColumnCount(B, transb), - RowCount(A, transa), - ColumnCount(A, transa), - static_cast(alpha), - reinterpret_cast(static_cast(B->data) - + B->byte_offset), + op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), + ColumnCount(A, transa), static_cast(alpha), + reinterpret_cast( + static_cast(B->data) + B->byte_offset), ColumnStride(B), - reinterpret_cast(static_cast(A->data) - + A->byte_offset), - ColumnStride(A), - static_cast(beta), - reinterpret_cast(static_cast(C->data) - + C->byte_offset), + reinterpret_cast( + static_cast(A->data) + A->byte_offset), + ColumnStride(A), static_cast(beta), + reinterpret_cast( + static_cast(C->data) + C->byte_offset), ColumnStride(C)); } -} // namespace contrib -} // namespace tvm +inline int ColumnStride3D(DLTensor *tensor) { + // If the tensor itself is transposed then it will have strides + // backward from what we expect. Regardless, the max of the strides + // (the other stride is 1) is the column stride. + if (tensor->strides) { + return std::max(tensor->strides[1], tensor->strides[2]); + } else { + return tensor->shape[2]; + } +} +inline int ElementStride3D(DLTensor *tensor) { + if (tensor->strides) { + return std::min(tensor->strides[1], tensor->strides[2]); + } else { + return 1; + } +} +// Reversed strides indicates an in-place transpose operation. +inline bool IsInPlaceTransposed3D(DLTensor *tensor) { + return tensor->strides && (tensor->strides[2] > tensor->strides[1]); +} +inline int BatchCount3D(DLTensor *tensor) { return tensor->shape[0]; } +inline int RowCount3D(DLTensor *tensor, bool trans) { + return tensor->shape[trans ? 2 : 1]; +} +inline int ColumnCount3D(DLTensor *tensor, bool trans) { + return tensor->shape[trans ? 1 : 2]; +} +template +inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) { + using DType = typename TBatchGemmOp::TDatatype; + DLTensor *A = args[0]; + DLTensor *B = args[1]; + DLTensor *C = args[2]; + bool transa = args[3]; + bool transb = args[4]; + int bit_depth = sizeof(DType) * 8; + CHECK_EQ(A->ndim, 3); + CHECK_EQ(B->ndim, 3); + CHECK_EQ(C->ndim, 3); + int batch_size = BatchCount3D(A); + CHECK_EQ(BatchCount3D(B), batch_size); + CHECK_EQ(BatchCount3D(C), batch_size); + CHECK_EQ(ElementStride(A), 1); + CHECK_EQ(ElementStride(B), 1); + CHECK_EQ(ElementStride(C), 1); + // C can never be transposed. + CHECK(!IsInPlaceTransposed3D(C)); + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed3D(A) ? !transa : transa; + transb = IsInPlaceTransposed3D(B) ? !transb : transb; + CHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); + CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); + double alpha = args.size() > 5 ? args[5] : 1.0; + double beta = args.size() > 6 ? args[6] : 0.0; + const int A_size = A->shape[1] * A->shape[2]; + const int B_size = B->shape[1] * B->shape[2]; + const int C_size = C->shape[1] * C->shape[2]; + DType *A_data = reinterpret_cast( + static_cast(A->data) + A->byte_offset); + DType *B_data = reinterpret_cast( + static_cast(B->data) + B->byte_offset); + DType *C_data = reinterpret_cast( + static_cast(C->data) + C->byte_offset); + op(batch_size, transb, transa, ColumnCount3D(B, transb), + RowCount3D(A, transa), ColumnCount3D(A, transa), static_cast(alpha), + B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A), + static_cast(beta), C_data, C_size, ColumnStride3D(C)); +} -#endif // TVM_CONTRIB_CBLAS_GEMM_COMMON_H_ +} // namespace contrib +} // namespace tvm diff --git a/tests/python/contrib/test_cblas.py b/tests/python/contrib/test_cblas.py index 6705328ee50a5..a56376a61d9be 100644 --- a/tests/python/contrib/test_cblas.py +++ b/tests/python/contrib/test_cblas.py @@ -16,19 +16,26 @@ # under the License. import tvm import numpy as np +import topi.testing from tvm.contrib import cblas -def test_matmul_add(): - n = 1024 - l = 128 - m = 235 +def test_matmul_add(m, l, n, transa=False, transb=False): bias = tvm.var('bias', dtype=tvm.float32) - A = tvm.placeholder((n, l), name='A') - B = tvm.placeholder((l, m), name='B') - C = cblas.matmul(A, B) + ashape = (l, n) if transa else (n, l) + bshape = (m, l) if transb else (l, m) + A = tvm.placeholder(ashape, name='A') + B = tvm.placeholder(bshape, name='B') + C = cblas.matmul(A, B, transa, transb) D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D") s = tvm.create_schedule(D.op) + def get_numpy(a, b, bb, transa, transb): + if transa: + a = a.transpose() + if transb: + b = b.transpose() + return np.dot(a, b) + bb + def verify(target="llvm"): if not tvm.module.enabled(target): print("skip because %s is not enabled..." % target) @@ -38,15 +45,66 @@ def verify(target="llvm"): return ctx = tvm.cpu(0) f = tvm.build(s, [A, B, D, bias], target) - a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx) d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx) bb = 10.0 f(a, b, d, bb) tvm.testing.assert_allclose( - d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb, rtol=1e-5) + d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), bb, transa, transb), rtol=1e-5) + verify() + + +def test_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False): + ashape = (batch, l, n) if transa else (batch, n, l) + bshape = (batch, m, l) if transb else (batch, l, m) + A = tvm.placeholder(ashape, name='A') + B = tvm.placeholder(bshape, name='B') + C = cblas.batch_matmul(A, B, transa, transb) + D = tvm.compute(C.shape, lambda k, i, j: C[k, i,j], name="D") + s = tvm.create_schedule(D.op) + + def get_numpy(a, b, transa, transb): + if transa: + a = a.transpose(0, 2, 1) + if not transb: + b = b.transpose(0, 2, 1) + return topi.testing.batch_matmul(a, b) + + def verify(target="llvm"): + if not tvm.module.enabled(target): + print("skip because %s is not enabled..." % target) + return + if not tvm.get_global_func("tvm.contrib.cblas.matmul", True): + print("skip because extern function is not available") + return + ctx = tvm.cpu(0) + f = tvm.build(s, [A, B, D], target) + a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx) + d = tvm.nd.array(np.zeros((batch, n, m), dtype=D.dtype), ctx) + f(a, b, d) + tvm.testing.assert_allclose( + d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5) verify() if __name__ == "__main__": - test_matmul_add() + test_matmul_add(235, 128, 1024) + test_matmul_add(235, 128, 1024, True, False) + test_matmul_add(235, 128, 1024, False, True) + test_matmul_add(235, 128, 1024, True, True) + test_matmul_add(1, 16, 4) + test_matmul_add(1, 16, 3, True, False) + test_matmul_add(1, 16, 3, False, False) + test_matmul_add(1, 16, 3, True, True) + + test_batch_matmul(16, 235, 128, 1024) + test_batch_matmul(16, 235, 128, 1024, True, False) + test_batch_matmul(16, 235, 128, 1024, False, True) + test_batch_matmul(16, 235, 128, 1024, True, True) + test_batch_matmul(1, 1, 16, 3) + test_batch_matmul(1, 1, 16, 3, True, False) + test_batch_matmul(1, 1, 16, 3, False, False) + test_batch_matmul(1, 1, 16, 3, True, True) + test_batch_matmul(1, 1, 16, 3, iterative=True)