From 2630ffcbc52973aaf86fd6b7000a6f2f30d5f25c Mon Sep 17 00:00:00 2001
From: wpan11nv <60017475+wpan11nv@users.noreply.github.com>
Date: Fri, 17 Jan 2020 18:58:11 -0800
Subject: [PATCH] [CodeGen][CUDA] Improve CUDA vectorizer (#4736)

- Fixes issues to enable the fp16 vectorizer. Correct packing and
  unpacking CUDA code is now emitted, and more unit tests are enabled.

- Do not emit code that reads the first lane of an undefined variable,
  such as

    int _3;
    _3 = _3 & ~(0x000000ff << 0) | ...

  and emit the following instead:

    _3 = (((0x000000ff & (_1 >> 0))+(0x000000ff & (_2 >> 0))) << 0);

  Note that nvcc 10.2 is forgiving and emits the same code for both
  cases, but the former triggers a warning when running
  test_codegen_cuda.py.

Signed-off-by: Wei Pan
---
 include/tvm/runtime/data_type.h            |  4 ++
 src/codegen/codegen_cuda.cc                | 43 +++++++++++++++++++---
 src/codegen/literal/cuda_half_t.h          | 10 +++++
 tests/python/unittest/test_codegen_cuda.py | 24 ++++++------
 4 files changed, 65 insertions(+), 16 deletions(-)

diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index cb58e9741d1f..7e0ef49154e4 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -92,6 +92,10 @@ class DataType {
   bool is_float() const {
     return code() == DataType::kFloat;
   }
+  /*! \return whether type is a float16 type. */
+  bool is_float16() const {
+    return is_float() && bits() == 16;
+  }
   /*! \return whether type is an int type. */
   bool is_int() const {
     return code() == DataType::kInt;
diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index 6f394a1e19bd..b6ba17fc381a 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -73,6 +73,7 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#else\n";
     decl_stream << _cuda_half_t_def;
     decl_stream << "#endif\n\n";
+    decl_stream << _cuda_half_util;
   }

   if (enable_int8_) {
@@ -122,8 +123,17 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
       if (lanes == 1) {
         os << "half";
       } else if (lanes <= 8) {
+        // Emit CUDA code to access fp16 vector elements.
+        //
+        // half4 is stored as uint2
+        //
+        // h4.x is emitted as *(half2*)(&(u2.x)).x
+        // h4.y is emitted as *(half2*)(&(u2.x)).y
+        // h4.z is emitted as *(half2*)(&(u2.y)).x
+        // h4.w is emitted as *(half2*)(&(u2.y)).y
+        //
         CHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
-        os << "float" << lanes / 2;
+        os << "uint" << lanes / 2;
       } else {
         fail = true;
       }
@@ -243,9 +253,12 @@ void CodeGenCUDA::PrintVecBinaryOp(
 void CodeGenCUDA::PrintVecElemLoad(
     const std::string& vec, DataType t, int i, std::ostream& os) {  // NOLINT(*)
   static const char access[] = {'x', 'y', 'z', 'w'};
-  CHECK(i >= 0 && i < 4);
+  CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if (t.is_int() && t.bits() == 8) {
     os << "(0x000000ff & (" << vec << " >> " << i * 8 << "))";
+  } else if (t.is_float16()) {
+    os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
+       << access[i % 2];
   } else {
     os << vec << "." << access[i];
   }
@@ -255,10 +268,17 @@ void CodeGenCUDA::PrintVecElemStore(
     const std::string& vec, DataType t, int i, const std::string& value) {
   this->PrintIndent();
   static const char access[] = {'x', 'y', 'z', 'w'};
-  CHECK(i >= 0 && i < 4);
+  CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if (t.is_int() && t.bits() == 8) {
-    stream << vec << "=" << vec << " & ~(0x000000ff << " << i * 8 << ") | ("
-           << value << " << " << i * 8 << ");\n";
+    stream << vec << "=";
+    // Do not read the first undef lane.
+    if (i != 0) {
+      stream << vec << " & ~(0x000000ff << " << i * 8 << ") |";
+    }
+    stream << "(" << value << " << " << i * 8 << ");\n";
+  } else if (t.is_float16()) {
+    stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
+           << access[i % 2] << " = " << value << ";\n";
   } else {
     stream << vec << "." << access[i] << " = " << value << ";\n";
   }
@@ -462,6 +482,19 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NOLINT(*)
     return;
   }

+  if (op->dtype.is_float16()) {
+    std::string v = PrintExpr(op->value);
+    os << "make_";
+    PrintType(op->dtype, os);
+    os << '(';
+    for (int i = 0; i < op->lanes / 2; ++i) {
+      if (i != 0) os << ", ";
+      os << "__pack_half2(" << v << ", " << v << ")";
+    }
+    os << ')';
+    return;
+  }
+
   std::string v = PrintExpr(op->value);
   os << "make_";
   PrintType(op->dtype, os);
diff --git a/src/codegen/literal/cuda_half_t.h b/src/codegen/literal/cuda_half_t.h
index 630a7413dc6c..7e9c72e437de 100644
--- a/src/codegen/literal/cuda_half_t.h
+++ b/src/codegen/literal/cuda_half_t.h
@@ -285,4 +285,14 @@ TVM_XINLINE half __float2half_rn(const float a) {
 }
 )";

+static constexpr const char* _cuda_half_util = R"(
+// Pack two half values.
+static inline __device__ __host__ unsigned
+__pack_half2(const half x, const half y) {
+  unsigned v0 = *((unsigned short *)&x);
+  unsigned v1 = *((unsigned short *)&y);
+  return (v0 << 16) | v1;
+}
+)";
+
 #endif  // TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_
diff --git a/tests/python/unittest/test_codegen_cuda.py b/tests/python/unittest/test_codegen_cuda.py
index 27a8d8746dd8..5d05b6d117b7 100644
--- a/tests/python/unittest/test_codegen_cuda.py
+++ b/tests/python/unittest/test_codegen_cuda.py
@@ -18,21 +18,23 @@ import tvm
 import numpy as np
 import unittest
-from tvm.contrib.nvcc import have_fp16, have_int8
+from tvm.contrib.nvcc import parse_compute_version, have_int8
 from tvm.contrib import nvcc

 tx = tvm.thread_axis("threadIdx.x")
 bx = tvm.thread_axis("blockIdx.x")

-
 def test_cuda_vectorize_add():
     num_thread = 8
     def check_cuda(dtype, n, lanes):
         if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
             print("skip because cuda is not enabled..")
             return
-        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
-            print("skip because gpu does not support fp16")
+        if dtype == "float16":
+            major, minor = parse_compute_version(tvm.gpu(0).compute_version)
+            # fp16 starts from 5.3
+            if major < 6 or (major == 5 and minor < 3):
+                print("skip because gpu does not support fp16")
                 return
         if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
             print("skip because gpu does not support int8")
             return
@@ -52,13 +54,13 @@ def check_cuda(dtype, n, lanes):
         tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)

     check_cuda("float32", 64, 2)
-    check_cuda("int8", 64, 4)
-    # check_cuda("float16", 64, 2)
-
-    # TODO(tvm-team) fix fp16 codegen here
-    # or hit an error if it is less frequently used.
-    # check_cuda("float16", 64, 2)
-
+    check_cuda("float32", 64, 3)
+    check_cuda("float32", 64, 4)
+    check_cuda("int8", 64, 4)
+    check_cuda("float16", 64, 2)
+    check_cuda("float16", 64, 4)
+    check_cuda("float16", 64, 6)
+    check_cuda("float16", 64, 8)

 def test_cuda_multiply_add():
     num_thread = 8
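
For reference, the conventions this patch adopts for a float16x4 value can be seen in the standalone CUDA sketch below. It is illustrative only, not TVM-generated code: the kernel, buffer names, and lane values are made up. It shows the three pieces working together: the vector is stored as a uint2, individual lanes are read and written through half2 pointer casts of the form ((half2*)(&(v.x)))->x, and broadcasts are built with a __pack_half2 helper like the one added to cuda_half_t.h. It assumes a GPU with compute capability 5.3 or newer and builds with, for example, nvcc -arch=sm_60 fp16x4_sketch.cu.

// fp16x4_sketch.cu: illustrative sketch of the emitted conventions, not TVM output.
#include <cuda_fp16.h>
#include <cstdio>

// Same packing helper this patch adds to cuda_half_t.h.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v0 << 16) | v1;
}

// Adds 1.0h to every lane of a float16x4 vector, written in the style the
// vectorizer now emits: uint2 storage plus per-lane half2 pointer casts.
__global__ void add_one_f16x4(const uint2* a, uint2* c, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  uint2 va = a[i];
  uint2 vc;
  // Broadcast 1.0h into all four lanes, as the BroadcastNode path does.
  // (Both packed lanes are identical, so the bit ordering chosen by
  // __pack_half2 does not affect the result.)
  uint2 ones = make_uint2(__pack_half2(__float2half(1.f), __float2half(1.f)),
                          __pack_half2(__float2half(1.f), __float2half(1.f)));
  // Lane access in the form PrintVecElemLoad/PrintVecElemStore emit:
  // h4.x -> ((half2*)(&(u2.x)))->x, h4.y -> ((half2*)(&(u2.x)))->y, ...
  ((half2*)(&(vc.x)))->x = __hadd(((half2*)(&(va.x)))->x, ((half2*)(&(ones.x)))->x);
  ((half2*)(&(vc.x)))->y = __hadd(((half2*)(&(va.x)))->y, ((half2*)(&(ones.x)))->y);
  ((half2*)(&(vc.y)))->x = __hadd(((half2*)(&(va.y)))->x, ((half2*)(&(ones.y)))->x);
  ((half2*)(&(vc.y)))->y = __hadd(((half2*)(&(va.y)))->y, ((half2*)(&(ones.y)))->y);
  c[i] = vc;
}

int main() {
  const int n = 64;  // number of float16x4 vectors
  uint2 *a = nullptr, *c = nullptr;
  cudaMallocManaged(&a, n * sizeof(uint2));
  cudaMallocManaged(&c, n * sizeof(uint2));
  for (int i = 0; i < n; ++i) {
    // All four lanes of vector i hold the value i.
    half h = __float2half((float)i);
    a[i] = make_uint2(__pack_half2(h, h), __pack_half2(h, h));
  }
  add_one_f16x4<<<1, n>>>(a, c, n);
  cudaDeviceSynchronize();
  // Lane 0 of vector 3 should now be 4.0.
  printf("c[3] lane 0 = %f\n", __half2float(((half2*)(&(c[3].x)))->x));
  cudaFree(a);
  cudaFree(c);
  return 0;
}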