Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Contrib] Implement batch_matmul with CBLAS #3210

Merged
merged 1 commit into from
May 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion cmake/modules/contrib/BLAS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ elseif(USE_BLAS STREQUAL "mkl")
if(NOT IS_DIRECTORY ${USE_MKL_PATH})
set(USE_MKL_PATH /opt/intel/mkl)
endif()
find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
if(APPLE)
find_library(BLAS_LIBRARY NAMES mklml HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
elseif(UNIX)
find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
endif()
include_directories(${USE_MKL_PATH}/include)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY})
list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC})
Expand Down
55 changes: 49 additions & 6 deletions python/tvm/contrib/cblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
"""External function interface to BLAS libraries."""
from __future__ import absolute_import as _abs

from .. import api as _api
from .. import intrin as _intrin
from .. import api as _api, intrin as _intrin

def matmul(lhs, rhs, transa=False, transb=False):

def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
"""Create an extern op that computes the matrix multiplication of lhs and rhs with CBLAS

This function serves as an example on how to call external libraries.
Expand All @@ -44,7 +44,50 @@ def matmul(lhs, rhs, transa=False, transb=False):
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
return _api.extern(
(n, m), [lhs, rhs],
(n, m),
[lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb
),
name="C",
**kwargs
)


def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs):
    """Create an extern op that computes a batched matrix multiplication of lhs and rhs with CBLAS.

    This function serves as an example on how to call external libraries.

    Parameters
    ----------
    lhs : Tensor
        The left matrix operand. ``lhs.shape[0]`` is used as the batch size of
        the result.
    rhs : Tensor
        The right matrix operand.
    transa : bool
        Whether to transpose lhs.
    transb : bool
        Whether to transpose rhs.
    iterative : bool
        When True, call the packed function
        "tvm.contrib.cblas.batch_matmul_iterative" instead of
        "tvm.contrib.cblas.batch_matmul".
    **kwargs
        Extra keyword arguments forwarded unchanged to ``_api.extern``.

    Returns
    -------
    C : Tensor
        The result tensor, of shape ``(b, n, m)`` as computed below.
    """
    b = lhs.shape[0]
    # The last two axes of each operand swap roles when that operand is transposed.
    n = lhs.shape[2] if transa else lhs.shape[1]
    m = rhs.shape[1] if transb else rhs.shape[2]
    return _api.extern(
        (b, n, m),
        [lhs, rhs],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.cblas.batch_matmul"
            if not iterative
            else "tvm.contrib.cblas.batch_matmul_iterative",
            ins[0],
            ins[1],
            outs[0],
            transa,
            transb,
        ),
        name="C",
        **kwargs
    )
171 changes: 131 additions & 40 deletions src/contrib/cblas/cblas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -21,12 +21,11 @@
* Copyright (c) 2017 by Contributors
* \file Use external cblas library call.
*/
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
#include "gemm_common.h"


extern "C" {
#if USE_MKL_BLAS == 1
#include <mkl_cblas.h>
Expand All @@ -40,56 +39,148 @@ namespace contrib {

using namespace runtime;

/*!
 * \brief Convert a boolean transpose flag into the CBLAS transpose enum.
 * \param trans Whether the operand is to be transposed.
 * \return CblasTrans when trans is true, CblasNoTrans otherwise.
 */
inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; }

/*!
 * \brief Functor that forwards a single-precision GEMM to cblas_sgemm.
 *
 * The call always uses CblasColMajor; the boolean flags are converted to
 * CBLAS_TRANSPOSE values with BooleanToTranspose and every other argument is
 * passed through unchanged.
 */
struct CblasSgemmOp {
  typedef float TDatatype;
  /*!
   * \param ta Whether A is transposed.
   * \param tb Whether B is transposed.
   * \param M, N, K Problem dimensions, forwarded to cblas_sgemm.
   * \param alpha Scalar forwarded as the alpha argument.
   * \param A, lda First operand and its leading dimension.
   * \param B, ldb Second operand and its leading dimension.
   * \param beta Scalar forwarded as the beta argument.
   * \param C, ldc Output matrix and its leading dimension.
   */
  void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
                  int ldb, float beta, float* C, int ldc) {
    cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
                lda, B, ldb, beta, C, ldc);
  }
};

/*!
 * \brief Functor that forwards a double-precision GEMM to cblas_dgemm.
 *
 * The call always uses CblasColMajor; the boolean flags are converted to
 * CBLAS_TRANSPOSE values with BooleanToTranspose and every other argument is
 * passed through unchanged.
 */
struct CblasDgemmOp {
  typedef double TDatatype;
  /*!
   * \param ta Whether A is transposed.
   * \param tb Whether B is transposed.
   * \param M, N, K Problem dimensions, forwarded to cblas_dgemm.
   * \param alpha Scalar forwarded as the alpha argument.
   * \param A, lda First operand and its leading dimension.
   * \param B, ldb Second operand and its leading dimension.
   * \param beta Scalar forwarded as the beta argument.
   * \param C, ldc Output matrix and its leading dimension.
   */
  void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda,
                  double* B, int ldb, double beta, double* C, int ldc) {
    cblas_dgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
                lda, B, ldb, beta, C, ldc);
  }
};

/*!
 * \brief Functor that runs batch_size single-precision GEMMs over strided operands.
 *
 * When USE_MKL_BLAS == 1 the whole batch is handed to cblas_sgemm_batch as a
 * single group of batch_size problems; otherwise cblas_sgemm is called once
 * per batch entry. Both paths use CblasColMajor.
 */
struct CblasSgemmBatchOp {
  typedef float TDatatype;
  /*!
   * \param batch_size Number of matrix products to compute.
   * \param ta Whether each A matrix is transposed.
   * \param tb Whether each B matrix is transposed.
   * \param M, N, K Problem dimensions shared by every batch entry.
   * \param alpha Scalar forwarded as the alpha argument.
   * \param A, a_stride, lda First operand, element offset between consecutive
   *        batch entries, and leading dimension.
   * \param B, b_stride, ldb Same for the second operand.
   * \param beta Scalar forwarded as the beta argument.
   * \param C, c_stride, ldc Same for the output.
   */
  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
                  int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
                  int c_stride, int ldc) {
    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
#if USE_MKL_BLAS == 1
    std::vector<const float*> A_array(batch_size);
    std::vector<const float*> B_array(batch_size);
    std::vector<float*> C_array(batch_size);
    // Walk the base pointers rather than computing `A + i * a_stride`: the
    // product of two ints can overflow for large batches, whereas advancing
    // the pointer one stride at a time (as the fallback path below does)
    // cannot.
    for (int i = 0; i < batch_size; ++i) {
      A_array[i] = A;
      B_array[i] = B;
      C_array[i] = C;
      A += a_stride;
      B += b_stride;
      C += c_stride;
    }
    // One group (group_count == 1) containing batch_size identical problems.
    cblas_sgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
                      B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
#else
    for (int i = 0; i < batch_size; ++i) {
      cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      A += a_stride;
      B += b_stride;
      C += c_stride;
    }
#endif
  }
};

struct CblasSgemmBatchIterativeOp {
typedef float TDatatype;
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
int c_stride, int ldc) {
CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
for (int i = 0; i < batch_size; ++i) {
cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
A += a_stride;
B += b_stride;
C += c_stride;
}
}
};

/*!
 * \brief Functor that runs batch_size double-precision GEMMs over strided operands.
 *
 * When USE_MKL_BLAS == 1 the whole batch is handed to cblas_dgemm_batch as a
 * single group of batch_size problems; otherwise cblas_dgemm is called once
 * per batch entry. Both paths use CblasColMajor.
 */
struct CblasDgemmBatchOp {
  typedef double TDatatype;
  /*!
   * \param batch_size Number of matrix products to compute.
   * \param ta Whether each A matrix is transposed.
   * \param tb Whether each B matrix is transposed.
   * \param M, N, K Problem dimensions shared by every batch entry.
   * \param alpha Scalar forwarded as the alpha argument.
   * \param A, a_stride, lda First operand, element offset between consecutive
   *        batch entries, and leading dimension.
   * \param B, b_stride, ldb Same for the second operand.
   * \param beta Scalar forwarded as the beta argument.
   * \param C, c_stride, ldc Same for the output.
   */
  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
                  int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
                  int c_stride, int ldc) {
    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
#if USE_MKL_BLAS == 1
    std::vector<const double*> A_array(batch_size);
    std::vector<const double*> B_array(batch_size);
    std::vector<double*> C_array(batch_size);
    // Walk the base pointers rather than computing `A + i * a_stride`: the
    // product of two ints can overflow for large batches, whereas advancing
    // the pointer one stride at a time (as the fallback path below does)
    // cannot.
    for (int i = 0; i < batch_size; ++i) {
      A_array[i] = A;
      B_array[i] = B;
      C_array[i] = C;
      A += a_stride;
      B += b_stride;
      C += c_stride;
    }
    // One group (group_count == 1) containing batch_size identical problems.
    cblas_dgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
                      B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
#else
    for (int i = 0; i < batch_size; ++i) {
      cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      A += a_stride;
      B += b_stride;
      C += c_stride;
    }
#endif
  }
};

struct CblasDgemmBatchIterativeOp {
typedef double TDatatype;
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
int c_stride, int ldc) {
CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
for (int i = 0; i < batch_size; ++i) {
cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
A += a_stride;
B += b_stride;
C += c_stride;
}
}
};

// matrix multiplication for row major
// Registers the packed function "tvm.contrib.cblas.matmul". The dtype of the
// first argument selects the single- or double-precision GEMM functor, which
// is then handed to CallGemm together with the untouched argument list.
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      DLTensor* A = args[0];
      // Only float32 and float64 inputs are supported.
      CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));

      if (TypeMatch(A->dtype, kDLFloat, 32)) {
        CallGemm(args, ret, CblasSgemmOp());
      } else {
        CallGemm(args, ret, CblasDgemmOp());
      }
    });
// Registers the packed function "tvm.contrib.cblas.batch_matmul". The dtype
// of the first argument selects the single- or double-precision batch functor,
// which is then handed to CallBatchGemm together with the argument list.
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      DLTensor* A = args[0];
      // Only float32 and float64 inputs are supported.
      CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
      if (TypeMatch(A->dtype, kDLFloat, 32)) {
        CallBatchGemm(args, ret, CblasSgemmBatchOp());
        return;
      }
      CallBatchGemm(args, ret, CblasDgemmBatchOp());
    });

// Registers the packed function "tvm.contrib.cblas.batch_matmul_iterative".
// The dtype of the first argument selects the single- or double-precision
// iterative batch functor, which is then handed to CallBatchGemm.
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      DLTensor* A = args[0];
      // Only float32 and float64 inputs are supported.
      CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
      if (TypeMatch(A->dtype, kDLFloat, 32)) {
        CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
        return;
      }
      CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp());
    });
} // namespace contrib
} // namespace tvm
Loading