Skip to content

Commit

Permalink
[Perf] Add CublasLt extern support for better Igemm performance (#4550)
Browse files Browse the repository at this point in the history
* cublaslt added

* fix lint

* address comments

* address more comments

* Trigger CI

* Trigger CI
  • Loading branch information
Laurawly authored and masahi committed Dec 29, 2019
1 parent f727810 commit fadea92
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 3 deletions.
51 changes: 51 additions & 0 deletions python/tvm/contrib/cublaslt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""External function interface to cuBLASlt libraries."""
from __future__ import absolute_import as _abs

from .. import api as _api
from .. import intrin as _intrin

def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None):
"""Create an extern op that compute matrix mult of A and rhs with cuBLAS
Parameters
----------
lhs : Tensor
The left matrix operand
rhs : Tensor
The right matrix operand
transa : bool
Whether transpose lhs
transb : bool
Whether transpose rhs
Returns
-------
C : Tensor
The result tensor.
"""
if n == 0:
n = lhs.shape[1] if transa else lhs.shape[0]
if m == 0:
m = rhs.shape[0] if transb else rhs.shape[1]
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern(
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cublaslt.matmul",
ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
111 changes: 109 additions & 2 deletions src/runtime/contrib/cublas/cublas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,98 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s
}
}

int roundoff(int v, int d) {
return (v + d - 1) / d * d;
}

#if CUDART_VERSION >= 10010
inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) {
DLTensor *A = args[0];
DLTensor *B = args[1];
DLTensor *C = args[2];
bool transa = args[3];
bool transb = args[4];
// Reversed strides indicates an in-place transpose operation.
transa = IsInPlaceTransposed(A) ? !transa : transa;
transb = IsInPlaceTransposed(B) ? !transb : transb;
int M = ColumnCount(B, transb);
int N = RowCount(A, transa);
int K = ColumnCount(A, transa);
int N_out = ColumnCount(C, false);
int m = M;
int n = m;
int k = m;
int lda = M * K / (roundoff(K, 32) / 32);
int ldb = K * N / (roundoff(K, 32) / 32);
int ldc = M * N_out / (roundoff(N_out, 32) / 32);
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);

CHECK_EQ(ElementStride(A), 1);
CHECK_EQ(ElementStride(B), 1);
CHECK_EQ(ElementStride(C), 1);

CHECK(TypeEqual(A->dtype, B->dtype));
CHECK(TypeMatch(A->dtype, kDLInt, 8));
CHECK(TypeMatch(C->dtype, kDLInt, 32));

CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type";
int32_t alpha = args.size() > 5 ? args[5] : 1;
int32_t beta = args.size() > 6 ? args[6] : 0;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
auto A_data = reinterpret_cast<void*>(static_cast<char*>(A->data) + A->byte_offset);
auto B_data = reinterpret_cast<void*>(static_cast<char*>(B->data) + B->byte_offset);
auto C_data = reinterpret_cast<void*>(static_cast<char*>(C->data) + C->byte_offset);

cublasOperation_t opTranspose = CUBLAS_OP_T;
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C;
cublasLtMatmulDesc_t operationDesc = nullptr;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(opTranspose)));
cublasOperation_t opTransA = BooleanToTranspose(transa);
cublasOperation_t opTransB = BooleanToTranspose(transb);
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTransA, sizeof(opTransA)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTransB, sizeof(opTransB)));
// Create descriptors for the original matrices
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(
&Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k ,
opTransA == CUBLAS_OP_N ? k : m, lda));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(
&Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n ,
opTransB == CUBLAS_OP_N ? n : k, ldb));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));

CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)));

CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl,
operationDesc,
&alpha,
B_data,
Adesc,
A_data,
Bdesc,
&beta,
C_data,
Cdesc,
C_data,
Cdesc,
NULL,
NULL,
0,
0));
}
#endif

inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) {
DLTensor *A = args[0];
DLTensor *B = args[1];
Expand Down Expand Up @@ -342,12 +434,27 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
}
});

TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul")
#if CUDART_VERSION >= 10010
TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
DLTensor* C = args[2];

CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();

TryEnableTensorCore(entry_ptr->handle);

CHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n";
cublasLtHandle_t ltHandle;
CHECK_CUBLAS_ERROR(cublasLtCreate(&ltHandle));
CallLtIgemm(args, ret, ltHandle);
CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle));
});
#endif // CUDART_VERSION >= 10010

TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
DLTensor* C = args[2];

CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();

Expand Down
6 changes: 6 additions & 0 deletions src/runtime/contrib/cublas/cublas_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
#include <dmlc/logging.h>
#include <dlpack/dlpack.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <cstdint>
#if CUDART_VERSION >= 10010
#include <cublasLt.h>
#endif // CUDART_VERSION >= 10010

namespace tvm {
namespace contrib {
Expand Down
65 changes: 64 additions & 1 deletion tests/python/contrib/test_cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tvm
import numpy as np
from tvm.contrib import cublas
from tvm.contrib import cublaslt

def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
n = 1024
Expand Down Expand Up @@ -44,6 +45,64 @@ def verify(target="cuda"):
c.asnumpy(), np.dot(a.asnumpy().astype(C.dtype), b.asnumpy().astype(C.dtype)), rtol=rtol)
verify()

def roundoff(v, d):
return int(np.floor((v + d - 1) / d) * d)

def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5):
n = 1024
l = 1024
m = 1024
L = roundoff(l, 32)
N = roundoff(n, 8)
N_out = roundoff(n, 32)

A = tvm.placeholder((N, L), name='A', dtype=in_dtype)
B = tvm.placeholder((m, L), name='B', dtype=in_dtype)
# C has CUBLASLT_ORDER_COL32 layout, thus a different shape
C = cublaslt.matmul(A, B, False, True, m, N_out, dtype=out_dtype)
s = tvm.create_schedule(C.op)

def verify(target="cuda"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.cublaslt.matmul", True):
print("skip because extern function is not available")
return
ctx = tvm.gpu(0)
f = tvm.build(s, [A, B, C], target)
a_old = np.random.uniform(0, 128, size=(n, l))
b_old = np.random.uniform(0, 128, size=(l, m))

# Transform a to become CUBLASLT_ORDER_COL4_4R2_8C layout
a_new = np.hstack((a_old.astype(A.dtype), np.zeros([n, L-l])))
a_new = np.vstack((a_new.astype(A.dtype), np.zeros([N-n, L])))
a_even = np.vsplit(a_new[::2], N / 8)
a_odd = np.vsplit(a_new[1::2], N / 8)
a_new = [None]*(len(a_even) + len(a_odd))
a_new[::2] = a_even
a_new[1::2] = a_odd
a_new = np.vstack(a_new)
a_new = np.vstack(np.vstack(np.vstack(np.hsplit(i, 8)).reshape([4, 32]) for i in np.vsplit(j, N/4)) for j in np.hsplit(a_new, L/32))
a_new = a_new.reshape([N, L])
# Transform b to become CUBLASLT_ORDER_COL32 layout
b_new = np.vstack(np.hsplit(np.hstack((b_old.T.astype(B.dtype), np.zeros([m, L - l]))), L / 32))
b_new = b_new.reshape([m, L])

a = tvm.nd.array(a_new.astype(A.dtype), ctx)
b = tvm.nd.array(b_new.astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((m, N_out), dtype=C.dtype), ctx)
f(a, b, c)
# Transform output c from layout CUBLASLT_ORDER_COL32 to row major layout
c_out = c.asnumpy()
c_out = c_out.reshape([int(m * N_out / 32), 32])
c_out = np.hstack(np.vsplit(c_out, int(N_out / 32)))
c_out = c_out[:, :n]
c_out = c_out.T
tvm.testing.assert_allclose(
c_out, np.dot(a_old.astype(C.dtype), b_old.astype(C.dtype)), rtol=rtol)
verify()

def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5):
j = 16
n = 1024
Expand Down Expand Up @@ -73,11 +132,14 @@ def verify(target="cuda"):
verify()

def test_matmul_add():
verify_matmul_add('float', 'float')
verify_matmul_add('float', 'float', rtol=1e-3)
verify_matmul_add('float16', 'float')
verify_matmul_add('float16', 'float16', rtol=1e-2)
verify_matmul_add('int8', 'int32')

def test_matmul_add_igemm():
verify_matmul_add_igemm('int8', 'int32')

def test_batch_matmul():
verify_batch_matmul('float', 'float')
verify_batch_matmul('float16', 'float')
Expand All @@ -86,4 +148,5 @@ def test_batch_matmul():
if __name__ == "__main__":
test_matmul_add()
test_batch_matmul()
test_matmul_add_igemm()

0 comments on commit fadea92

Please sign in to comment.