[CodeGen][CUDA] Fix issues in cuda codegen (apache#4876)
- Do not emit __shared__ etc. as part of the type when casting.

- Fix compiler errors in fp16 reduction kernels:

    "no operator "+" matches these operands, volatile half + volatile half"

  This patch inserts casts that strip the volatile qualifier from the
  result of volatile loads (fp16 only); see the repro sketch below. The
  proper long-term fix is for the CUDA fp16 library headers to add
  volatile member functions.

- Update have_fp16 to include compute 6.1 GPUs, which do support fp16,
  although their fp16 throughput is low. Updated tests.
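
For context, here is a minimal repro sketch of the failing pattern (not part of this patch; the kernel and buffer names are made up). Cross-thread reductions access shared memory through volatile pointers, but __half in cuda_fp16.h provides no volatile-qualified operators or member functions, so the addition cannot be resolved. The sketch assumes a 64-thread block and a target of sm_53 or newer, where half arithmetic is available at all.

    // Repro sketch only; names are hypothetical. Launch with 64 threads.
    #include <cuda_fp16.h>

    __global__ void fp16_reduce_repro(half* out) {
      __shared__ half red_buf[64];
      red_buf[threadIdx.x] = __float2half(1.0f);
      __syncthreads();
      volatile half* vbuf = red_buf;
      (void)vbuf;  // only referenced in the comment below
      if (threadIdx.x < 32) {
        // Before this commit, TVM emitted the reduction step through the
        // volatile pointer, which nvcc rejects with:
        //   no operator "+" matches these operands (volatile half + volatile half)
        // vbuf[threadIdx.x] = vbuf[threadIdx.x] + vbuf[threadIdx.x + 32];
        //
        // Plain (non-volatile) accesses resolve fine:
        red_buf[threadIdx.x] = red_buf[threadIdx.x] + red_buf[threadIdx.x + 32];
      }
      __syncthreads();
      if (threadIdx.x == 0) *out = red_buf[0];
    }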

Signed-off-by: Wei Pan <[email protected]>
wpan11nv authored and alexwong committed Feb 26, 2020
1 parent d992041 commit e229c49
Showing 8 changed files with 105 additions and 54 deletions.
6 changes: 1 addition & 5 deletions python/tvm/contrib/nvcc.py
@@ -232,11 +232,7 @@ def have_fp16(compute_version):
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#arithmetic-instructions
if major == 5 and minor == 3:
return True
# NOTE: exclude compute capability 6.1 devices although it is actually available
# to compute fp16, because these devices only have low-rate fp16 performance.
if major == 6 and minor != 1:
return True
if major == 7:
if major >= 6:
return True

return False
13 changes: 7 additions & 6 deletions src/target/source/codegen_c.cc
@@ -153,14 +153,15 @@ std::string CodeGenC::GetBufferRef(
if (alloc_storage_scope_.count(buffer)) {
scope = alloc_storage_scope_.at(buffer);
}
bool is_vol = volatile_buf_.count(buffer) != 0;
bool is_vol = IsVolatile(buffer);
if (t.lanes() == 1) {
if (!HandleTypeMatch(buffer, t) || is_vol) {
os << "((";
if (is_vol) {
os << "volatile ";
}
if (scope.length() != 0) {
// Scope may not be part of type.
if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, os);
}
os << ' ';
@@ -189,15 +190,15 @@ std::string CodeGenC::GetBufferRef(
if (is_vol) {
os << "volatile ";
}
if (scope.length() != 0) {
if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, os);
}
os << ' ';
PrintType(t, os);
os << "*)(";
if (!HandleTypeMatch(buffer, t.element_of())) {
os << '(';
if (scope.length() != 0) {
if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, os);
}
os << ' ';
@@ -620,14 +621,14 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
// delcare type.
if (op->dtype.lanes() == 1) {
std::string ref = GetBufferRef(op->dtype, op->buffer_var.get(), op->index);
os << ref;
HandleVolatileLoads(ref, op, os);
} else {
CHECK(is_one(op->predicate))
<< "predicated load is not supported";
PrimExpr base;
if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) {
std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);
os << ref;
HandleVolatileLoads(ref, op, os);
} else {
// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
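
To make the change above concrete: for a scalar load from a volatile buffer, GetBufferRef still prints the volatile qualifier, but the storage scope is now printed only when IsScopePartOfType() returns true, and the resulting reference string is routed through HandleVolatileLoads instead of being streamed directly. The fragment below is illustrative only; the names "A", "i", and "load_like_codegen" are hypothetical.

    // With IsScopePartOfType() == true (e.g. an OpenCL-style backend), the
    // emitted reference would look like ((volatile __local float*)A)[i].
    // With IsScopePartOfType() == false (CUDA), the scope is dropped from
    // the cast and only the volatile qualifier remains:
    __device__ float load_like_codegen(float* A, int i) {
      return ((volatile float*)A)[i];
    }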
34 changes: 33 additions & 1 deletion src/target/source/codegen_c.h
@@ -178,9 +178,36 @@ class CodeGenC :
// Print reference to struct location
std::string GetStructRef(
DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind);
// print reference to a buffer as type t in index.
// Print reference to a buffer as type t in index.
virtual std::string GetBufferRef(
DataType t, const VarNode* buffer, PrimExpr index);

/*!
* \brief Handle volatile loads.
*
* This is to workaround a bug in CUDA cuda_fp16.h. Volatile accesses
* to shared memory are required for reductions. However, __half class
* does not implement volatile member functions. CUDA codegen will cast
* away volatile qualifier from CUDA __half types.
*/
virtual void HandleVolatileLoads(const std::string& value, const LoadNode* op,
std::ostream& os) {
// By default, do nothing but print the loaded value.
os << value;
}

/*!
* \brief Check if scope is part of type in the target language.
*
* **NOTE** In OpenCL, __local is part of type, so "__local int *"
* is legal. This is not the case for CUDA, where "__shared__"
* or "__constant__" is not part of type but a storage class (like
* C/C++ static).
*/
virtual bool IsScopePartOfType() const {
return true;
}

/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
@@ -205,6 +232,11 @@ class CodeGenC :
/*! \brief reserves common C keywords */
void ReserveKeywordsAsUnique();

/*! \brief Check if buf_var is volatile or not. */
bool IsVolatile(const VarNode *buf_var) const {
return volatile_buf_.count(buf_var) != 0;
}

private:
/*! \brief whether to print in SSA form */
bool print_ssa_form_{false};
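
The note on IsScopePartOfType() is the crux of the first bullet in the commit message: in CUDA, __shared__ and __constant__ say where a variable lives but are not part of its type, so they do not belong inside a cast. A small hand-written illustration (not generated by TVM; names are made up, launch with 32 threads):

    __global__ void scope_cast_demo(float* out) {
      __shared__ float buf[32];           // __shared__ is a storage class here
      buf[threadIdx.x] = (float)threadIdx.x;
      __syncthreads();
      // Fine: the cast type carries no storage scope.
      volatile float* p = (volatile float*)buf;
      // With IsScopePartOfType() == false the codegen no longer emits a cast
      // of the form
      //   (volatile __shared__ float*)buf
      // because __shared__ is not a type qualifier in CUDA.
      out[threadIdx.x] = p[threadIdx.x];
    }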
28 changes: 14 additions & 14 deletions src/target/source/codegen_cuda.cc
@@ -57,20 +57,6 @@ std::string CodeGenCUDA::Finish() {
<< "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
decl_stream << "__device__ half min(half a, half b)\n"
<< "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
// FIXME(tvm-team): "volatile" is used to enable cross thread reduction,
// which is needed by operations such as softmax.
// However, volatile overloading is not supported in NVRTC and CUDA < 9.2.
// We need to figure out a solution which can satisfy both scenario.
// decl_stream << "__device__ half operator<="
// << "(const volatile __half &a, const volatile __half &b)\n"
// << "{\n return __hlt(a, b);\n}\n";
// decl_stream << "__device__ half operator+"
// << "(const volatile __half &a, const volatile __half &b)\n"
// <<"{\n return __hadd(a, b);\n}\n";
// decl_stream << "__device__ half operator*"
// << "(const volatile __half &a, const volatile __half &b)\n"
// << "{\n return __hmul(a, b);\n}\n";
// otherwise simulate computation via float32
decl_stream << "#else\n";
decl_stream << _cuda_half_t_def;
decl_stream << "#endif\n\n";
@@ -605,5 +591,19 @@ int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope,
return 0;
}

void CodeGenCUDA::HandleVolatileLoads(const std::string& value,
const LoadNode* op, std::ostream& os) {
// Cast away volatile qualifier for fp16 types. That is, only loads and
// stores are volatile. The loaded objects are not marked as volatile.
//
if (op->dtype.is_float16() && IsVolatile(op->buffer_var.get())) {
os << "(";
PrintType(op->dtype, os);
os << ")(" << value << ")";
} else {
os << value;
}
}

} // namespace codegen
} // namespace tvm
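
Putting the two overrides together, the effect on the emitted source for an fp16 cross-thread reduction step looks roughly like the following (illustrative generated code; the names "red_buf" and "tx" are hypothetical):

    // Without the cast, a reduction step such as
    //   ((volatile half*)red_buf)[tx] =
    //       ((volatile half*)red_buf)[tx] + ((volatile half*)red_buf)[tx + 16];
    // fails because operator+ does not accept volatile half operands.
    //
    // With HandleVolatileLoads, each loaded operand is cast back to a plain
    // half, while the store is still performed through the volatile reference:
    //   ((volatile half*)red_buf)[tx] =
    //       (half)(((volatile half*)red_buf)[tx]) +
    //       (half)(((volatile half*)red_buf)[tx + 16]);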
9 changes: 9 additions & 0 deletions src/target/source/codegen_cuda.h
@@ -66,6 +66,15 @@ class CodeGenCUDA final : public CodeGenC {
void VisitStmt_(const AttrStmtNode *op) final;

private:
// Handle volatile loads
void HandleVolatileLoads(const std::string& value, const LoadNode* op,
std::ostream& os) final;

// Whether scope such as "__shared__" or "__constant__" is part of type.
bool IsScopePartOfType() const final {
return false;
}

// Whether global barrier is needed.
bool need_global_barrier_{false};
// Global barrier state
41 changes: 35 additions & 6 deletions tests/python/unittest/test_codegen_cuda.py
@@ -17,8 +17,9 @@
# under the License.
import tvm
import numpy as np
import topi
import unittest
from tvm.contrib.nvcc import parse_compute_version, have_int8
from tvm.contrib.nvcc import have_fp16, have_int8
from tvm.contrib import nvcc

tx = tvm.thread_axis("threadIdx.x")
@@ -30,11 +31,8 @@ def check_cuda(dtype, n, lanes):
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return
if dtype == "float16":
major, minor = parse_compute_version(tvm.gpu(0).compute_version)
# fp16 starts from 5.3
if major < 6 or (major == 5 and minor < 3):
print("skip because gpu does not support fp16")
if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
print("skip because gpu does not support int8")
@@ -291,6 +289,36 @@ def test_cuda_const_float_to_half():
func(a, c)
np.testing.assert_equal(c.asnumpy(), a_np > b.value)

def test_cuda_reduction():
def check_cuda(dtype, m=32, n=32):
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return
if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
print("Skip because gpu does not have fp16 support")
return

a = tvm.placeholder((m, n), name="a", dtype=dtype)
b = tvm.placeholder((m, n), name="b", dtype=dtype)
c = a + b
d = a * b
e = topi.elemwise_sum([c, d])
g = topi.sum(e)
with tvm.target.cuda():
sg = topi.generic.schedule_reduce(g)
ctx = tvm.gpu(0)
func = tvm.build(sg, [a, b, g], 'cuda')
a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
b_np = np.random.uniform(size=(m, n)).astype(b.dtype)
g_np = np.sum(np.add(a_np * b_np, a_np + b_np))
a_nd = tvm.nd.array(a_np, ctx)
b_nd = tvm.nd.array(b_np, ctx)
g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g_np.dtype), ctx)
func(a_nd, b_nd, g_nd)
tvm.testing.assert_allclose(g_nd.asnumpy(), g_np, rtol=1e-3)

check_cuda("float32")
check_cuda("float16")

if __name__ == "__main__":
test_cuda_vectorize_add()
@@ -302,3 +330,4 @@ def test_cuda_const_float_to_half():
test_cuda_reducition_binding()
test_rfactor_predicates()
test_cuda_const_float_to_half()
test_cuda_reduction()
14 changes: 3 additions & 11 deletions topi/tests/python/test_topi_relu.py
@@ -20,18 +20,9 @@
import tvm
import topi
from topi.util import get_const_tuple
from tvm.contrib.nvcc import parse_compute_version
from tvm.contrib.nvcc import have_fp16
from common import get_all_backend

def skip_test(dtype, device):
if dtype == "float16" and device == "cuda":
major, minor = parse_compute_version(tvm.gpu(0).compute_version)
# fp16 starts from 5.3
if major < 6 or (major == 5 and minor < 3):
print("skip because gpu does not support fp16")
return True
return False

def verify_relu(m, n, dtype="float32"):
A = tvm.placeholder((m, n), name='A', dtype=dtype)
B = topi.nn.relu(A)
@@ -44,7 +35,8 @@ def check_device(device):
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
if skip_test(dtype, device):
if dtype == "float16" and device == "cuda" and not have_fp16(tvm.gpu(0).compute_version):
print("Skip because %s does not have fp16 support" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
14 changes: 3 additions & 11 deletions topi/tests/python/test_topi_tensor.py
@@ -19,16 +19,7 @@
import tvm
import topi
from tvm.contrib.pickle_memoize import memoize
from tvm.contrib.nvcc import parse_compute_version

def skip_test(dtype, device):
if dtype == "float16" and device == "cuda":
major, minor = parse_compute_version(tvm.gpu(0).compute_version)
# fp16 starts from 5.3
if major < 6 or (major == 5 and minor < 3):
print("skip because gpu does not support fp16")
return True
return False
from tvm.contrib.nvcc import have_fp16

def verify_elemwise_sum(num_args, dtype):
shape = (3,5,4)
@@ -99,7 +90,8 @@ def check_device(device):
if not tvm.runtime.enabled(device):
print("Skip because %s is not enabled" % device)
return
if skip_test(dtype, device):
if dtype == "float16" and device == "cuda" and not have_fp16(tvm.gpu(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
with tvm.target.create(device):
ctx = tvm.context(device, 0)
