diff --git a/topi/python/topi/cuda/injective.py b/topi/python/topi/cuda/injective.py index b77a97924716..eb7019bd7654 100644 --- a/topi/python/topi/cuda/injective.py +++ b/topi/python/topi/cuda/injective.py @@ -40,13 +40,20 @@ def schedule_injective_from_existing(sch, out): num_thread = tvm.target.Target.current(allow_none=False).max_num_threads max_block = 256 + # vectorize on fp16 data type. This allows to better utilize the memory + # bandwidth. + vector_width = 4 if out.dtype == "float16" else 1 + try: const_size = util.get_const_int(util.prod(out.shape)) - max_block = 256 - need_block_split = const_size > max_block * num_thread + need_block_split = const_size > max_block * num_thread * vector_width except ValueError: need_block_split = False + if vector_width > 1: + fused, v = sch[out].split(fused, vector_width) + sch[out].vectorize(v) + if need_block_split: xo, xi = sch[out].split(fused, factor=num_thread * max_block) bx, tx = sch[out].split(xi, factor=num_thread) diff --git a/topi/tests/python/test_topi_relu.py b/topi/tests/python/test_topi_relu.py index 5c41647846cb..414edbca4f0f 100644 --- a/topi/tests/python/test_topi_relu.py +++ b/topi/tests/python/test_topi_relu.py @@ -20,11 +20,20 @@ import tvm import topi from topi.util import get_const_tuple - +from tvm.contrib.nvcc import parse_compute_version from common import get_all_backend -def verify_relu(m, n): - A = tvm.placeholder((m, n), name='A') +def skip_test(dtype, device): + if dtype == "float16" and device == "cuda": + major, minor = parse_compute_version(tvm.gpu(0).compute_version) + # fp16 starts from 5.3 + if major < 6 or (major == 5 and minor < 3): + print("skip because gpu does not support fp16") + return True + return False + +def verify_relu(m, n, dtype="float32"): + A = tvm.placeholder((m, n), name='A', dtype=dtype) B = topi.nn.relu(A) a_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(A.shape)).astype(A.dtype) @@ -35,6 +44,8 @@ def check_device(device): if not ctx.exist: print("Skip because %s is not enabled" % device) return + if skip_test(dtype, device): + return print("Running on target: %s" % device) with tvm.target.create(device): s = topi.generic.schedule_elemwise(B) @@ -87,12 +98,12 @@ def _prelu_numpy(x, W): tvm.testing.assert_allclose(b.asnumpy(), out_np, rtol=1e-5) def test_relu(): - verify_relu(10, 128) + verify_relu(10, 128, "float32") + verify_relu(128, 64, "float16") def test_schedule_big_array(): verify_relu(1024 * 100 , 512) - def test_leaky_relu(): verify_leaky_relu(100, 0.1) diff --git a/topi/tests/python/test_topi_tensor.py b/topi/tests/python/test_topi_tensor.py index 465d98e5f082..84718ff3a647 100644 --- a/topi/tests/python/test_topi_tensor.py +++ b/topi/tests/python/test_topi_tensor.py @@ -19,6 +19,16 @@ import tvm import topi from tvm.contrib.pickle_memoize import memoize +from tvm.contrib.nvcc import parse_compute_version + +def skip_test(dtype, device): + if dtype == "float16" and device == "cuda": + major, minor = parse_compute_version(tvm.gpu(0).compute_version) + # fp16 starts from 5.3 + if major < 6 or (major == 5 and minor < 3): + print("skip because gpu does not support fp16") + return True + return False def verify_elemwise_sum(num_args, dtype): shape = (3,5,4) @@ -84,18 +94,43 @@ def check_device(device): for device in ["llvm"]: check_device(device) +def verify_vectorization(n, m, dtype): + def check_device(device): + if not tvm.runtime.enabled(device): + print("Skip because %s is not enabled" % device) + return + if skip_test(dtype, device): + return + with tvm.target.create(device): + ctx = tvm.context(device, 0) + A = tvm.placeholder((n, m), name='A', dtype=dtype) + B = tvm.compute((n, m), lambda i, j: + A[i, j] + tvm.const(1, A.dtype), name='B') + S = topi.generic.schedule_elemwise(B) + + fun = tvm.build(S, [A, B], device) + np_A = tvm.nd.empty((n, m), A.dtype, ctx).copyfrom( + np.random.uniform(size=(n, m))) + np_B = tvm.nd.empty((n, m), B.dtype, ctx) + fun(np_A, np_B) + tvm.testing.assert_allclose(np_B.asnumpy(), np_A.asnumpy() + 1, rtol=1e-5) + + for device in ["cuda"]: + check_device(device) + +def test_vectorization(): + verify_vectorization(128, 64, "float16") def test_elemwise_sum(): verify_elemwise_sum(1, "float32") verify_elemwise_sum(5, "float32") verify_elemwise_sum(4, "int32") - def test_full(): verify_full((3,4,5), "float32", 3.14) verify_full((10,), "int32", 7) - if __name__ == "__main__": test_elemwise_sum() test_full() + test_vectorization()