Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Topi, x86] Using MKL blas for quantized dense #6115

Merged
merged 5 commits into from
Jul 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions python/tvm/contrib/cblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,39 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
)


def matmul_u8s8s32(lhs, rhs, transa=False, transb=False, **kwargs):
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS
This function serves as an example on how to call external libraries.
Parameters
----------
lhs: Tensor
The left matrix operand
rhs: Tensor
The right matrix operand
transa: bool
Whether transpose lhs
transb: bool
Whether transpose rhs
Returns
-------
C: Tensor
The result tensor.
"""
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
return te.extern(
(n, m),
[lhs, rhs],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cblas.matmul_u8s8s32", ins[0], ins[1], outs[0], transa, transb
),
name="C",
**kwargs
)


def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs):
"""Create an extern op that compute batched matrix mult of A and rhs with CBLAS
This function serves as an example on how to call external libraries.
Expand Down
41 changes: 41 additions & 0 deletions src/runtime/contrib/cblas/cblas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,37 @@ using namespace runtime;

inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; }

#if USE_MKL_BLAS == 1
inline CBLAS_OFFSET StringToOffset(const std::string offset_type) {
if (offset_type != "CblasFixOffset" && offset_type != "CblasColOffset" &&
offset_type != "CblasRowOffset") {
LOG(FATAL) << "Unrecognized offset_type " << offset_type;
}
if (offset_type == "CblasFixOffset") {
return CblasFixOffset;
} else if (offset_type == "CblasColOffset") {
return CblasColOffset;
}
return CblasRowOffset;
}
#endif

inline char BooleanToTransposeChar(bool trans) { return trans ? 'T' : 'N'; }

struct CblasGemmU8S8S32Op {
void operator()(bool ta, bool tb, int M, int N, int K, float alpha, const void* A, int lda,
int offset_a, const void* B, int ldb, int offset_b, float beta, int* C, int ldc,
const std::string offset_ctype, int* offset_c) {
#if USE_MKL_BLAS == 1
cblas_gemm_s8u8s32(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb),
StringToOffset(offset_ctype), M, N, K, alpha, A, lda, offset_a, B, ldb,
offset_b, beta, C, ldc, offset_c);
#else
LOG(FATAL) << "Quantized Gemm is supported with MKL Blas only";
#endif
}
};

struct CblasSgemmOp {
typedef float TDatatype;
void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
Expand Down Expand Up @@ -170,6 +199,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul").set_body([](TVMArgs args, TVMRet
CallGemm(args, ret, CblasDgemmOp());
});

// integer matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul_u8s8s32")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
CHECK(TypeMatch(A->dtype, kDLUInt, 8) && TypeMatch(B->dtype, kDLInt, 8) &&
TypeMatch(C->dtype, kDLInt, 32));

CallU8S8S32Gemm(args, ret, CblasGemmU8S8S32Op());
});

TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
Expand Down
48 changes: 48 additions & 0 deletions src/runtime/contrib/cblas/gemm_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <string>

namespace tvm {
namespace contrib {
Expand Down Expand Up @@ -99,6 +100,53 @@ inline void CallGemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) {
ColumnStride(C));
}

// Call a column major blas. Note that data is stored in tvm as row
// major, so this we switch the arguments.
template <typename TGemmOp>
inline void CallU8S8S32Gemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) {
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];

// Set the sgemm attributes. Currently, support is limited to CblasFixOffset with all offsets
// equal to 0. This is sufficient for relay dense.
std::string offset_ctype = "CblasFixOffset";
int16_t offset_a = 0;
int16_t offset_b = 0;
int offset_c[1];
offset_c[0] = 0;

CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);

CHECK_EQ(ElementStride(A), 1);
CHECK_EQ(ElementStride(B), 1);
CHECK_EQ(ElementStride(C), 1);

// C can never be transposed.
CHECK(!IsInPlaceTransposed(C));

// Reversed strides indicates an in-place transpose operation.
transa = IsInPlaceTransposed(A) ? !transa : transa;
transb = IsInPlaceTransposed(B) ? !transb : transb;

CHECK(TypeMatch(A->dtype, kDLUInt, 8));
CHECK(TypeMatch(B->dtype, kDLInt, 8));
CHECK(TypeMatch(C->dtype, kDLInt, 32));
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa),
static_cast<float>(alpha),
reinterpret_cast<void*>(static_cast<char*>(B->data) + B->byte_offset), ColumnStride(B),
offset_b, reinterpret_cast<void*>(static_cast<char*>(A->data) + A->byte_offset),
ColumnStride(A), offset_a, static_cast<float>(beta),
reinterpret_cast<int*>(static_cast<char*>(C->data) + C->byte_offset), ColumnStride(C),
offset_ctype, offset_c);
}

inline int ColumnStride3D(DLTensor* tensor) {
// If the tensor itself is transposed then it will have strides
// backward from what we expect. Regardless, the max of the strides
Expand Down
52 changes: 52 additions & 0 deletions tests/python/contrib/test_cblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
from tvm import te
import numpy as np
Expand Down Expand Up @@ -65,6 +66,57 @@ def test_matmul_add():
verify_matmul_add(1, 16, 3, False, False)
verify_matmul_add(1, 16, 3, True, True)

def verify_quantized_matmul_add(m, l, n, transa=False, transb=False):
pytest.skip("Quantized dense is supported only for MKL. TVM GPU CI uses openblas")
data_dtype = "uint8"
kernel_dtype = "int8"
out_dtype = "int32"
bias = te.var('bias', dtype=out_dtype)
ashape = (l, n) if transa else (n, l)
bshape = (m, l) if transb else (l, m)
A = te.placeholder(ashape, name='A', dtype=data_dtype)
B = te.placeholder(bshape, name='B', dtype=kernel_dtype)
C = cblas.matmul_u8s8s32(A, B, transa, transb, dtype=out_dtype)
D = te.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
s = te.create_schedule(D.op)

def get_numpy(a, b, bb, transa, transb):
if transa:
a = a.transpose()
if transb:
b = b.transpose()
return np.dot(a, b) + bb

def verify(target="llvm"):
if not tvm.runtime.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.cblas.matmul_u8s8s32", True):
print("skip because extern function is not available")
return
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D, bias], target)
a = tvm.nd.array(np.random.randint(low=0, high=50, size=ashape).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.randint(low=0, high=50, size=bshape).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
bb = 10
f(a, b, d, bb)
tvm.testing.assert_allclose(
d.asnumpy(),
get_numpy(a.asnumpy().astype('int32'), b.asnumpy().astype('int32'), bb, transa, transb),
rtol=1e-5)
verify()

def test_quantized_matmul_add():
verify_quantized_matmul_add(235, 128, 1024)
verify_quantized_matmul_add(235, 128, 1024, True, False)
verify_quantized_matmul_add(235, 128, 1024, False, True)
verify_quantized_matmul_add(235, 128, 1024, True, True)
verify_quantized_matmul_add(1, 16, 4)
verify_quantized_matmul_add(1, 16, 3, True, False)
verify_quantized_matmul_add(1, 16, 3, False, True)
verify_quantized_matmul_add(1, 16, 3, True, True)

def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False, dtype="float32"):
ashape = (batch, l, n) if transa else (batch, n, l)
bshape = (batch, m, l) if transb else (batch, l, m)
Expand Down
8 changes: 7 additions & 1 deletion topi/python/topi/x86/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,13 @@ def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
M, K = get_const_tuple(data.shape)
N, _ = get_const_tuple(weight.shape)
cfg.add_flop(M * K * N * 2)
C = cblas.matmul(data, weight, False, True)
if data.dtype == 'uint8' and weight.dtype == 'int8' and out_dtype == 'int32':
C = cblas.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype)
elif data.dtype == 'float32':
C = cblas.matmul(data, weight, False, True)
else:
raise NotImplementedError(f"Dense with cblas for {data.dtype} is not supported")

if bias is not None:
C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST)
Expand Down