[Topi, x86] Using MKL blas for quantized dense (#6115)
* [Topi, x86] Using MKL blas for quantized dense

* Typo

* CBLAS_OFFSET only available for MKL

* Skipping tests as GPU CI uses Openblas

* Retrigger

Co-authored-by: Ubuntu <[email protected]>
anijain2305 and Ubuntu authored Jul 28, 2020
1 parent 1e9e4b9 commit 8cd53e0
Showing 5 changed files with 181 additions and 1 deletion.
33 changes: 33 additions & 0 deletions python/tvm/contrib/cblas.py
@@ -52,6 +52,39 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
    )


def matmul_u8s8s32(lhs, rhs, transa=False, transb=False, **kwargs):
    """Create an extern op that computes matrix multiplication of lhs and rhs with CBLAS.
    This function serves as an example of how to call external libraries.

    Parameters
    ----------
    lhs: Tensor
        The left matrix operand
    rhs: Tensor
        The right matrix operand
    transa: bool
        Whether to transpose lhs
    transb: bool
        Whether to transpose rhs

    Returns
    -------
    C: Tensor
        The result tensor.
    """
    n = lhs.shape[1] if transa else lhs.shape[0]
    m = rhs.shape[0] if transb else rhs.shape[1]
    return te.extern(
        (n, m),
        [lhs, rhs],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.cblas.matmul_u8s8s32", ins[0], ins[1], outs[0], transa, transb
        ),
        name="C",
        **kwargs
    )


def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs):
"""Create an extern op that compute batched matrix mult of A and rhs with CBLAS
This function serves as an example on how to call external libraries.
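
A minimal usage sketch of the new Python entry point (illustrative only: it assumes a TVM build with MKL BLAS, matching the USE_MKL_BLAS guard in cblas.cc below, so the packed function resolves at runtime; shapes and values are arbitrary):

# Hypothetical usage sketch; requires a TVM build with MKL BLAS.
import numpy as np
import tvm
from tvm import te
from tvm.contrib import cblas

m, k, n = 4, 8, 3
A = te.placeholder((m, k), name="A", dtype="uint8")          # data must be uint8
B = te.placeholder((k, n), name="B", dtype="int8")           # weight must be int8
C = cblas.matmul_u8s8s32(A, B, False, False, dtype="int32")  # accumulate in int32

s = te.create_schedule(C.op)
f = tvm.build(s, [A, B, C], "llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(0, 50, size=(m, k)).astype("uint8"), ctx)
b = tvm.nd.array(np.random.randint(-50, 50, size=(k, n)).astype("int8"), ctx)
c = tvm.nd.array(np.zeros((m, n), dtype="int32"), ctx)
f(a, b, c)
np.testing.assert_array_equal(
    c.asnumpy(), a.asnumpy().astype("int32") @ b.asnumpy().astype("int32"))
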
41 changes: 41 additions & 0 deletions src/runtime/contrib/cblas/cblas.cc
@@ -44,8 +44,37 @@ using namespace runtime;

inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; }

#if USE_MKL_BLAS == 1
inline CBLAS_OFFSET StringToOffset(const std::string offset_type) {
  if (offset_type != "CblasFixOffset" && offset_type != "CblasColOffset" &&
      offset_type != "CblasRowOffset") {
    LOG(FATAL) << "Unrecognized offset_type " << offset_type;
  }
  if (offset_type == "CblasFixOffset") {
    return CblasFixOffset;
  } else if (offset_type == "CblasColOffset") {
    return CblasColOffset;
  }
  return CblasRowOffset;
}
#endif

inline char BooleanToTransposeChar(bool trans) { return trans ? 'T' : 'N'; }

struct CblasGemmU8S8S32Op {
  void operator()(bool ta, bool tb, int M, int N, int K, float alpha, const void* A, int lda,
                  int offset_a, const void* B, int ldb, int offset_b, float beta, int* C, int ldc,
                  const std::string offset_ctype, int* offset_c) {
#if USE_MKL_BLAS == 1
    cblas_gemm_s8u8s32(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb),
                       StringToOffset(offset_ctype), M, N, K, alpha, A, lda, offset_a, B, ldb,
                       offset_b, beta, C, ldc, offset_c);
#else
    LOG(FATAL) << "Quantized Gemm is supported with MKL Blas only";
#endif
  }
};

struct CblasSgemmOp {
typedef float TDatatype;
void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
@@ -170,6 +199,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
  CallGemm(args, ret, CblasDgemmOp());
});

// Integer matrix multiplication for row-major layout.
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul_u8s8s32")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      DLTensor* A = args[0];
      DLTensor* B = args[1];
      DLTensor* C = args[2];
      CHECK(TypeMatch(A->dtype, kDLUInt, 8) && TypeMatch(B->dtype, kDLInt, 8) &&
            TypeMatch(C->dtype, kDLInt, 32));

      CallU8S8S32Gemm(args, ret, CblasGemmU8S8S32Op());
    });

TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
  DLTensor* A = args[0];
  CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
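
The registered packed function can also be exercised directly on NDArrays, which is how the Python wrapper reaches it through tvm.tir.call_packed. A sketch (again assuming an MKL-enabled build):

# Sketch: drive tvm.contrib.cblas.matmul_u8s8s32 directly (MKL build assumed).
import numpy as np
import tvm

func = tvm.get_global_func("tvm.contrib.cblas.matmul_u8s8s32", allow_missing=True)
if func is not None:
    a = tvm.nd.array(np.random.randint(0, 50, size=(2, 3)).astype("uint8"))
    b = tvm.nd.array(np.random.randint(-5, 5, size=(3, 4)).astype("int8"))
    c = tvm.nd.array(np.zeros((2, 4), dtype="int32"))
    # Argument order mirrors CallU8S8S32Gemm: A, B, C, transa, transb;
    # optional alpha/beta (positions 5 and 6) default to 1.0 and 0.0.
    func(a, b, c, False, False)
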
48 changes: 48 additions & 0 deletions src/runtime/contrib/cblas/gemm_common.h
@@ -27,6 +27,7 @@
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <string>

namespace tvm {
namespace contrib {
@@ -99,6 +100,53 @@ inline void CallGemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) {
     ColumnStride(C));
}

// Call a column-major BLAS. Note that data is stored in TVM as row-major,
// so we switch the arguments.
template <typename TGemmOp>
inline void CallU8S8S32Gemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) {
  DLTensor* A = args[0];
  DLTensor* B = args[1];
  DLTensor* C = args[2];
  bool transa = args[3];
  bool transb = args[4];

  // Set the gemm attributes. Currently, support is limited to CblasFixOffset with all offsets
  // equal to 0. This is sufficient for relay dense.
  std::string offset_ctype = "CblasFixOffset";
  int16_t offset_a = 0;
  int16_t offset_b = 0;
  int offset_c[1];
  offset_c[0] = 0;

  CHECK_EQ(A->ndim, 2);
  CHECK_EQ(B->ndim, 2);
  CHECK_EQ(C->ndim, 2);

  CHECK_EQ(ElementStride(A), 1);
  CHECK_EQ(ElementStride(B), 1);
  CHECK_EQ(ElementStride(C), 1);

  // C can never be transposed.
  CHECK(!IsInPlaceTransposed(C));

  // Reversed strides indicate an in-place transpose operation.
  transa = IsInPlaceTransposed(A) ? !transa : transa;
  transb = IsInPlaceTransposed(B) ? !transb : transb;

  CHECK(TypeMatch(A->dtype, kDLUInt, 8));
  CHECK(TypeMatch(B->dtype, kDLInt, 8));
  CHECK(TypeMatch(C->dtype, kDLInt, 32));
  double alpha = args.size() > 5 ? args[5] : 1.0;
  double beta = args.size() > 6 ? args[6] : 0.0;
  op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa),
     static_cast<float>(alpha),
     reinterpret_cast<void*>(static_cast<char*>(B->data) + B->byte_offset), ColumnStride(B),
     offset_b, reinterpret_cast<void*>(static_cast<char*>(A->data) + A->byte_offset),
     ColumnStride(A), offset_a, static_cast<float>(beta),
     reinterpret_cast<int*>(static_cast<char*>(C->data) + C->byte_offset), ColumnStride(C),
     offset_ctype, offset_c);
}

inline int ColumnStride3D(DLTensor* tensor) {
  // If the tensor itself is transposed then it will have strides
  // backward from what we expect. Regardless, the max of the strides
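
CallU8S8S32Gemm passes the operands to the column-major GEMM in swapped order because a row-major buffer read as column-major is exactly the transposed matrix, and A @ B equals the transpose of B^T @ A^T, so the column-major routine ends up writing a row-major C = A @ B. A small numpy check of that identity (illustrative only):

# Illustrative numpy check of the layout identity behind the argument swap.
import numpy as np

A = np.random.randint(0, 50, size=(4, 8)).astype(np.uint8)
B = np.random.randint(-5, 5, size=(8, 3)).astype(np.int8)

row_major = A.astype(np.int32) @ B.astype(np.int32)
# A column-major GEMM fed the row-major buffers of B and A computes B^T @ A^T;
# its transpose is exactly A @ B.
swapped = (B.astype(np.int32).T @ A.astype(np.int32).T).T
np.testing.assert_array_equal(row_major, swapped)
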
52 changes: 52 additions & 0 deletions tests/python/contrib/test_cblas.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
from tvm import te
import numpy as np
@@ -65,6 +66,57 @@ def test_matmul_add():
    verify_matmul_add(1, 16, 3, False, False)
    verify_matmul_add(1, 16, 3, True, True)

def verify_quantized_matmul_add(m, l, n, transa=False, transb=False):
    pytest.skip("Quantized dense is supported only for MKL. TVM GPU CI uses openblas")
    data_dtype = "uint8"
    kernel_dtype = "int8"
    out_dtype = "int32"
    bias = te.var('bias', dtype=out_dtype)
    ashape = (l, n) if transa else (n, l)
    bshape = (m, l) if transb else (l, m)
    A = te.placeholder(ashape, name='A', dtype=data_dtype)
    B = te.placeholder(bshape, name='B', dtype=kernel_dtype)
    C = cblas.matmul_u8s8s32(A, B, transa, transb, dtype=out_dtype)
    D = te.compute(C.shape, lambda i, j: C[i, j] + bias, name="D")
    s = te.create_schedule(D.op)

    def get_numpy(a, b, bb, transa, transb):
        if transa:
            a = a.transpose()
        if transb:
            b = b.transpose()
        return np.dot(a, b) + bb

    def verify(target="llvm"):
        if not tvm.runtime.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        if not tvm.get_global_func("tvm.contrib.cblas.matmul_u8s8s32", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.cpu(0)
        f = tvm.build(s, [A, B, D, bias], target)
        a = tvm.nd.array(np.random.randint(low=0, high=50, size=ashape).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.randint(low=0, high=50, size=bshape).astype(B.dtype), ctx)
        d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
        bb = 10
        f(a, b, d, bb)
        tvm.testing.assert_allclose(
            d.asnumpy(),
            get_numpy(a.asnumpy().astype('int32'), b.asnumpy().astype('int32'), bb, transa, transb),
            rtol=1e-5)
    verify()

def test_quantized_matmul_add():
    verify_quantized_matmul_add(235, 128, 1024)
    verify_quantized_matmul_add(235, 128, 1024, True, False)
    verify_quantized_matmul_add(235, 128, 1024, False, True)
    verify_quantized_matmul_add(235, 128, 1024, True, True)
    verify_quantized_matmul_add(1, 16, 4)
    verify_quantized_matmul_add(1, 16, 3, True, False)
    verify_quantized_matmul_add(1, 16, 3, False, True)
    verify_quantized_matmul_add(1, 16, 3, True, True)

def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False, dtype="float32"):
    ashape = (batch, l, n) if transa else (batch, n, l)
    bshape = (batch, m, l) if transb else (batch, l, m)
8 changes: 7 additions & 1 deletion topi/python/topi/x86/dense.py
@@ -226,7 +226,13 @@ def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
    M, K = get_const_tuple(data.shape)
    N, _ = get_const_tuple(weight.shape)
    cfg.add_flop(M * K * N * 2)
    C = cblas.matmul(data, weight, False, True)
    if data.dtype == 'uint8' and weight.dtype == 'int8' and out_dtype == 'int32':
        C = cblas.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype)
    elif data.dtype == 'float32':
        C = cblas.matmul(data, weight, False, True)
    else:
        raise NotImplementedError(f"Dense with cblas for {data.dtype} is not supported")

    if bias is not None:
        C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype),
                       tag=tag.BROADCAST)
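
With this dispatch in place, a quantized Relay dense (uint8 data, int8 weight, int32 output) compiled with a cblas-enabled target should take the matmul_u8s8s32 path when TVM is built against MKL. A hedged end-to-end sketch (target string and shapes are illustrative):

# Sketch: route a quantized dense through dense_cblas (MKL build assumed).
import tvm
from tvm import relay

data = relay.var("data", shape=(8, 128), dtype="uint8")
weight = relay.var("weight", shape=(64, 128), dtype="int8")
out = relay.nn.dense(data, weight, out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

# "-libs=cblas" selects the dense_cblas schedule in the x86 strategy.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -libs=cblas")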
