diff --git a/python/tvm/contrib/cublaslt.py b/python/tvm/contrib/cublaslt.py new file mode 100644 index 000000000000..5470fd0b4c18 --- /dev/null +++ b/python/tvm/contrib/cublaslt.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""External function interface to cuBLASlt libraries.""" +from __future__ import absolute_import as _abs + +from .. import api as _api +from .. import intrin as _intrin + +def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None): + """Create an extern op that compute matrix mult of A and rhs with cuBLAS + + Parameters + ---------- + lhs : Tensor + The left matrix operand + rhs : Tensor + The right matrix operand + transa : bool + Whether transpose lhs + transb : bool + Whether transpose rhs + + Returns + ------- + C : Tensor + The result tensor. + """ + if n == 0: + n = lhs.shape[1] if transa else lhs.shape[0] + if m == 0: + m = rhs.shape[0] if transb else rhs.shape[1] + dtype = dtype if dtype is not None else lhs.dtype + return _api.extern( + (n, m), [lhs, rhs], + lambda ins, outs: _intrin.call_packed( + "tvm.contrib.cublaslt.matmul", + ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C") diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index bbb2d2e952cc..2cb677729654 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -181,6 +181,98 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s } } +int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + +#if CUDART_VERSION >= 10010 +inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) { + DLTensor *A = args[0]; + DLTensor *B = args[1]; + DLTensor *C = args[2]; + bool transa = args[3]; + bool transb = args[4]; + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed(A) ? !transa : transa; + transb = IsInPlaceTransposed(B) ? !transb : transb; + int M = ColumnCount(B, transb); + int N = RowCount(A, transa); + int K = ColumnCount(A, transa); + int N_out = ColumnCount(C, false); + int m = M; + int n = m; + int k = m; + int lda = M * K / (roundoff(K, 32) / 32); + int ldb = K * N / (roundoff(K, 32) / 32); + int ldc = M * N_out / (roundoff(N_out, 32) / 32); + CHECK_EQ(A->ndim, 2); + CHECK_EQ(B->ndim, 2); + CHECK_EQ(C->ndim, 2); + + CHECK_EQ(ElementStride(A), 1); + CHECK_EQ(ElementStride(B), 1); + CHECK_EQ(ElementStride(C), 1); + + CHECK(TypeEqual(A->dtype, B->dtype)); + CHECK(TypeMatch(A->dtype, kDLInt, 8)); + CHECK(TypeMatch(C->dtype, kDLInt, 32)); + + CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; + int32_t alpha = args.size() > 5 ? args[5] : 1; + int32_t beta = args.size() > 6 ? args[6] : 0; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + + cublasOperation_t opTranspose = CUBLAS_OP_T; + cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; + cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C; + cublasLtMatmulDesc_t operationDesc = nullptr; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I)); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(opTranspose))); + cublasOperation_t opTransA = BooleanToTranspose(transa); + cublasOperation_t opTransB = BooleanToTranspose(transb); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTransA, sizeof(opTransA))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTransB, sizeof(opTransB))); + // Create descriptors for the original matrices + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( + &Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k , + opTransA == CUBLAS_OP_N ? k : m, lda)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( + &Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n , + opTransB == CUBLAS_OP_N ? n : k, ldb)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc)); + + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32))); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C))); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32))); + + CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, + operationDesc, + &alpha, + B_data, + Adesc, + A_data, + Bdesc, + &beta, + C_data, + Cdesc, + C_data, + Cdesc, + NULL, + NULL, + 0, + 0)); +} +#endif + inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { DLTensor *A = args[0]; DLTensor *B = args[1]; @@ -342,12 +434,27 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") } }); -TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") +#if CUDART_VERSION >= 10010 +TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") .set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* A = args[0]; - DLTensor* C = args[2]; + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + TryEnableTensorCore(entry_ptr->handle); + + CHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; + cublasLtHandle_t ltHandle; + CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); + CallLtIgemm(args, ret, ltHandle); + CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); +}); +#endif // CUDART_VERSION >= 10010 + +TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") +.set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* C = args[2]; CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 17e123219089..2e553e28493b 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -27,6 +27,12 @@ #include #include #include +#include +#include +#include +#if CUDART_VERSION >= 10010 +#include +#endif // CUDART_VERSION >= 10010 namespace tvm { namespace contrib { diff --git a/tests/python/contrib/test_cublas.py b/tests/python/contrib/test_cublas.py index 85268b95a7a8..4d4789663a9f 100644 --- a/tests/python/contrib/test_cublas.py +++ b/tests/python/contrib/test_cublas.py @@ -17,6 +17,7 @@ import tvm import numpy as np from tvm.contrib import cublas +from tvm.contrib import cublaslt def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5): n = 1024 @@ -44,6 +45,64 @@ def verify(target="cuda"): c.asnumpy(), np.dot(a.asnumpy().astype(C.dtype), b.asnumpy().astype(C.dtype)), rtol=rtol) verify() +def roundoff(v, d): + return int(np.floor((v + d - 1) / d) * d) + +def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5): + n = 1024 + l = 1024 + m = 1024 + L = roundoff(l, 32) + N = roundoff(n, 8) + N_out = roundoff(n, 32) + + A = tvm.placeholder((N, L), name='A', dtype=in_dtype) + B = tvm.placeholder((m, L), name='B', dtype=in_dtype) + # C has CUBLASLT_ORDER_COL32 layout, thus a different shape + C = cublaslt.matmul(A, B, False, True, m, N_out, dtype=out_dtype) + s = tvm.create_schedule(C.op) + + def verify(target="cuda"): + if not tvm.module.enabled(target): + print("skip because %s is not enabled..." % target) + return + if not tvm.get_global_func("tvm.contrib.cublaslt.matmul", True): + print("skip because extern function is not available") + return + ctx = tvm.gpu(0) + f = tvm.build(s, [A, B, C], target) + a_old = np.random.uniform(0, 128, size=(n, l)) + b_old = np.random.uniform(0, 128, size=(l, m)) + + # Transform a to become CUBLASLT_ORDER_COL4_4R2_8C layout + a_new = np.hstack((a_old.astype(A.dtype), np.zeros([n, L-l]))) + a_new = np.vstack((a_new.astype(A.dtype), np.zeros([N-n, L]))) + a_even = np.vsplit(a_new[::2], N / 8) + a_odd = np.vsplit(a_new[1::2], N / 8) + a_new = [None]*(len(a_even) + len(a_odd)) + a_new[::2] = a_even + a_new[1::2] = a_odd + a_new = np.vstack(a_new) + a_new = np.vstack(np.vstack(np.vstack(np.hsplit(i, 8)).reshape([4, 32]) for i in np.vsplit(j, N/4)) for j in np.hsplit(a_new, L/32)) + a_new = a_new.reshape([N, L]) + # Transform b to become CUBLASLT_ORDER_COL32 layout + b_new = np.vstack(np.hsplit(np.hstack((b_old.T.astype(B.dtype), np.zeros([m, L - l]))), L / 32)) + b_new = b_new.reshape([m, L]) + + a = tvm.nd.array(a_new.astype(A.dtype), ctx) + b = tvm.nd.array(b_new.astype(B.dtype), ctx) + c = tvm.nd.array(np.zeros((m, N_out), dtype=C.dtype), ctx) + f(a, b, c) + # Transform output c from layout CUBLASLT_ORDER_COL32 to row major layout + c_out = c.asnumpy() + c_out = c_out.reshape([int(m * N_out / 32), 32]) + c_out = np.hstack(np.vsplit(c_out, int(N_out / 32))) + c_out = c_out[:, :n] + c_out = c_out.T + tvm.testing.assert_allclose( + c_out, np.dot(a_old.astype(C.dtype), b_old.astype(C.dtype)), rtol=rtol) + verify() + def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5): j = 16 n = 1024 @@ -73,11 +132,14 @@ def verify(target="cuda"): verify() def test_matmul_add(): - verify_matmul_add('float', 'float') + verify_matmul_add('float', 'float', rtol=1e-3) verify_matmul_add('float16', 'float') verify_matmul_add('float16', 'float16', rtol=1e-2) verify_matmul_add('int8', 'int32') +def test_matmul_add_igemm(): + verify_matmul_add_igemm('int8', 'int32') + def test_batch_matmul(): verify_batch_matmul('float', 'float') verify_batch_matmul('float16', 'float') @@ -86,4 +148,5 @@ def test_batch_matmul(): if __name__ == "__main__": test_matmul_add() test_batch_matmul() + test_matmul_add_igemm()