diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index b48f647688c5..241310fd00d4 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -50,6 +50,20 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
 std::string CodeGenCUDA::Finish() {
   if (enable_fp16_) {
     decl_stream << "#include <cuda_fp16.h>\n";
+    decl_stream << "__device__ half max" \
+    "(const half a, const half b)\n"
+    "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
+    decl_stream << "__device__ half min(const half a, const half b)\n"
+    "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
+    decl_stream << "__device__ half operator+" \
+    "(const volatile __half &a, const volatile __half &b)\n"
+    "{\n return __hadd(a, b);\n}\n";
+    decl_stream << "__device__ half operator<=" \
+    "(const volatile __half &a, const volatile __half &b)\n"
+    "{\n return __hlt(a, b);\n}\n";
+    decl_stream << "__device__ half operator*" \
+    "(const volatile __half &a, const volatile __half &b)\n"
+    "{\n return __hmul(a, b);\n}\n";
   }

   if (enable_int8_) {
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index d31f4f46f5d7..4a07662554b9 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -21,6 +21,7 @@
 from tvm.relay import transform
 from tvm.relay.testing import ctx_list
 import topi.testing
+from tvm.contrib.nvcc import have_fp16

 def run_infer_type(expr):
     mod = relay.Module.from_expr(expr)
@@ -42,11 +43,11 @@ def rsqrt(x):
     return one / np.sqrt(x)

 def test_unary_op():
-    def check_single_op(opfunc, ref):
+    def check_single_op(opfunc, ref, dtype):
         shape = (10, 4)
-        dtype = 'float32'
-        tp = relay.TensorType(shape, dtype)
-        x = relay.var("x", tp)
+        dtype = dtype
+        tp = relay.TensorType(shape)
+        x = relay.var("x", tp, dtype=dtype)
         y = opfunc(x)
         # test printer
         assert ("{}(%x)".format(y.op.name)) in y.astext()
@@ -61,6 +62,8 @@ def check_single_op(opfunc, ref):
         for target, ctx in ctx_list():
             # use graph by execuor default for testing, as we need
             # create function explicitly to avoid constant-folding.
+            if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
+                continue
             intrp = relay.create_executor("graph", ctx=ctx, target=target)
             op_res = intrp.evaluate(func)(data)
             np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
@@ -77,22 +80,23 @@ def check_single_op(opfunc, ref):
                         (tvm.relay.cos, np.cos),
                         (tvm.relay.sin, np.sin),
                         (tvm.relay.atan, np.arctan)]:
-        check_single_op(opfunc, ref)
+        for dtype in ['float16', 'float32']:
+            check_single_op(opfunc, ref, dtype)


 def test_binary_op():
     def inst(vars, sh):
         return [vars.get(s, s) for s in sh]

-    def check_binary_op(opfunc, ref):
+    def check_binary_op(opfunc, ref, dtype):
         # TODO(@jroesch): this piece of code improperly uses type variables.
         n = tvm.var("n")
         s1 = (5, n, 5)
         s2 = (n, 1)
         t1 = relay.TensorType(s1)
         t2 = relay.TensorType(s2)
-        x = relay.var("x", t1)
-        y = relay.var("y", t2)
+        x = relay.var("x", t1, dtype=dtype)
+        y = relay.var("y", t2, dtype=dtype)
         z = opfunc(x, y)
         # test printer
         assert ("{}(%x, %y)".format(z.op.name)) in z.astext()
@@ -102,17 +106,19 @@ def check_binary_op(opfunc, ref):
         if ref is not None:
             t1 = relay.TensorType((5, 10, 5))
             t2 = relay.TensorType((5, 10, 5))
-            x = relay.var("x", t1)
-            y = relay.var("y", t2)
+            x = relay.var("x", t1, dtype=dtype)
+            y = relay.var("y", t2, dtype=dtype)
             z = opfunc(x, y)
-            x_data = np.random.rand(5, 10, 5).astype(t1.dtype)
-            y_data = np.random.rand(5, 10, 5).astype(t2.dtype)
+            x_data = np.random.rand(5, 10, 5).astype(dtype)
+            y_data = np.random.rand(5, 10, 5).astype(dtype)
             ref_res = ref(x_data, y_data)
             func = relay.Function([x, y], z)

             for target, ctx in ctx_list():
                 # use graph by execuor default for testing, as we need
                 # create function explicitly to avoid constant-folding.
+                if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
+                    continue
                 intrp = relay.create_executor("graph", ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data, y_data)
                 np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
@@ -121,7 +127,8 @@ def check_binary_op(opfunc, ref):
                         (relay.subtract, np.subtract),
                         (relay.multiply, np.multiply),
                         (relay.divide, np.divide)]:
-        check_binary_op(opfunc, ref)
+        for dtype in ['float16', 'float32']:
+            check_binary_op(opfunc, ref, dtype)


 def test_expand_dims():
@@ -130,226 +137,249 @@ def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis):
         x = relay.Var("x", relay.TensorType(dshape, dtype))
         func = relay.Function([x], relay.expand_dims(x, axis, num_newaxis))
         for target, ctx in ctx_list():
+            if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
+                continue
             data = np.random.uniform(size=dshape).astype(dtype)
             ref_res = data.reshape(oshape)
             intrp = relay.create_executor("graph", ctx=ctx, target=target)
             op_res = intrp.evaluate(func)(data)
             np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
-
-    verify_expand_dims((3, 10), 'float32', (3, 10, 1, 1), 2, 2)
-    verify_expand_dims((3, 10), 'float32', (1, 3, 10), -3, 1)
+    for dtype in ['float16', 'float32']:
+        verify_expand_dims((3, 10), dtype, (3, 10, 1, 1), 2, 2)
+        verify_expand_dims((3, 10), dtype, (1, 3, 10), -3, 1)


 def test_bias_add():
-    xshape=(10, 2, 3, 4)
-    bshape=(2,)
-    dtype="float32"
-    x = relay.var("x", shape=xshape)
-    bias = relay.var("bias")
-    z = relay.nn.bias_add(x, bias)
-    zz = run_infer_type(z)
-    assert "axis=" not in zz.astext()
-    assert zz.args[1].checked_type == relay.TensorType(bshape)
-
-    func = relay.Function([x, bias], z)
-    x_data = np.random.uniform(size=xshape).astype(dtype)
-    y_data = np.random.uniform(size=bshape).astype(dtype)
-    ref_res = x_data + y_data.reshape((2, 1, 1))
-    for target, ctx in ctx_list():
-        intrp = relay.create_executor("graph", ctx=ctx, target=target)
-        op_res = intrp.evaluate(func)(x_data, y_data)
-        np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+    for dtype in ['float16', 'float32']:
+        xshape=(10, 2, 3, 4)
+        bshape=(2,)
+        rtol = 1e-2 if dtype is 'float16' else 1e-5
+        x = relay.var("x", shape=xshape, dtype=dtype)
+        bias = relay.var("bias", dtype=dtype)
+        z = relay.nn.bias_add(x, bias)
+        zz = run_infer_type(z)
+        assert "axis=" not in zz.astext()
+        assert zz.args[1].checked_type == relay.TensorType(bshape, dtype)
+
+        func = relay.Function([x, bias], z)
+        x_data = np.random.uniform(size=xshape).astype(dtype)
+        y_data = np.random.uniform(size=bshape).astype(dtype)
+        ref_res = x_data + y_data.reshape((2, 1, 1))
+        for target, ctx in ctx_list():
+            if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
+                continue
+            intrp = relay.create_executor("graph", ctx=ctx, target=target)
+            op_res = intrp.evaluate(func)(x_data, y_data)
+            np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol)


 def test_expand_dims_infer_type():
-    n, t, d = tvm.var("n"), tvm.var("t"), 100
-    x = relay.var("x", shape=(n, t, d))
-    y = relay.expand_dims(x, axis=2)
-    assert "axis=2" in y.astext()
-    yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, t, 1, 100))
+    for dtype in ['float16', 'float32']:
+        n, t, d = tvm.var("n"), tvm.var("t"), 100
+        x = relay.var("x", shape=(n, t, d), dtype=dtype)
+        y = relay.expand_dims(x, axis=2)
+        assert "axis=2" in y.astext()
+        yy = run_infer_type(y)
+        assert yy.checked_type == relay.TensorType((n, t, 1, 100), dtype)


 def test_softmax():
-    shape = (10, 4)
-    x = relay.var("x", shape=shape)
-    y = relay.nn.softmax(x, axis=1)
-    assert "nn.softmax" in y.astext()
-    yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType(shape)
-    func = relay.Function([x], y)
-    x_data = np.random.uniform(size=shape).astype("float32")
-    ref_res = topi.testing.softmax_python(x_data)
-    for target, ctx in ctx_list():
-        intrp = relay.create_executor("graph", ctx=ctx, target=target)
-        op_res = intrp.evaluate(func)(x_data)
-        np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+    for dtype in ['float16', 'float32']:
+        # Softmax accuracy for float16 is poor
+        if dtype == 'float16':
+            return
+        shape = (10, 4)
+        x = relay.var("x", shape=shape, dtype=dtype)
+        y = relay.nn.softmax(x, axis=1)
+        assert "nn.softmax" in y.astext()
+        yy = run_infer_type(y)
+        assert yy.checked_type == relay.TensorType(shape, dtype)
+        func = relay.Function([x], y)
+        x_data = np.random.uniform(size=shape).astype(dtype)
+        ref_res = topi.testing.softmax_python(x_data)
+        for target, ctx in ctx_list():
+            intrp = relay.create_executor("graph", ctx=ctx, target=target)
+            op_res = intrp.evaluate(func)(x_data)
+            np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)


 def test_log_softmax():
-    shape = (10, 4)
-    x = relay.var("x", shape=shape)
-    y = relay.nn.log_softmax(x, axis=1)
-    assert "nn.log_softmax" in y.astext()
-    yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType(shape)
-    func = relay.Function([x], y)
-    x_data = np.random.uniform(size=shape).astype("float32")
-    ref_res = topi.testing.log_softmax_python(x_data)
-    for target, ctx in ctx_list():
-        intrp = relay.create_executor("graph", ctx=ctx, target=target)
-        op_res = intrp.evaluate(func)(x_data)
-        np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+    for dtype in ['float16', 'float32']:
+        # Softmax accuracy for float16 is poor
+        if dtype == 'float16':
+            return
+        shape = (10, 4)
+        x = relay.var("x", shape=shape, dtype=dtype)
+        y = relay.nn.log_softmax(x, axis=1)
+        assert "nn.log_softmax" in y.astext()
+        yy = run_infer_type(y)
+        assert yy.checked_type == relay.TensorType(shape, dtype)
+        func = relay.Function([x], y)
+        x_data = np.random.uniform(size=shape).astype(dtype)
+        ref_res = topi.testing.log_softmax_python(x_data)
+        for target, ctx in ctx_list():
+            intrp = relay.create_executor("graph", ctx=ctx, target=target)
+            op_res = intrp.evaluate(func)(x_data)
+            np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)


 def test_concatenate():
-    n, t, d = tvm.var("n"), tvm.var("t"), 100
-    x = relay.var("x", shape=(n, t, d))
-    y = relay.var("y", shape=(n, t, d))
-    z = relay.concatenate((x, y), axis=-1)
-    assert "axis=" in z.astext()
-    zz = run_infer_type(z)
-    assert zz.checked_type == relay.TensorType((n, t, 200))
-
-    x = relay.exp(x)
-    z = relay.concatenate((x, y), axis=2)
-    zz = run_infer_type(z)
-    assert zz.checked_type == relay.TensorType((n, t, 200))
-
-    z = relay.concatenate((x, y), axis=1)
-    zz = run_infer_type(z)
-    assert zz.checked_type == relay.TensorType((n, t + t, 100))
-
-    # check shape mismatches (the following case is expected to raise tvm._ffi.base.TVMError.
-    try:
-        x = relay.var('p1', shape=(2, 5))
-        y = relay.var('p2', shape=(2, 3))
-        c = relay.concatenate([x, y], axis=0)
-        func = relay.Function([x, y], c)
-        zz = run_infer_type(func)
-    except tvm._ffi.base.TVMError:
-        pass
-    else:
-        assert False
-
-    x = relay.var("x", shape=(10, 5))
-    y = relay.var("y", shape=(10, 5))
-    t = relay.var("z", shape=())
-    z = relay.concatenate((x, y), axis=1)
-    z = relay.add(z, t)
-    # Check result.
-    func = relay.Function([x, y, t], z)
-    x_data = np.random.rand(10, 5).astype('float32')
-    y_data = np.random.rand(10, 5).astype('float32')
-    t_data = np.random.uniform(size=()).astype('float32')
-    ref_res = np.concatenate((x_data, y_data), axis=1) + t_data
-
-    for target, ctx in ctx_list():
-        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
-        intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
-        op_res1 = intrp1.evaluate(func)(x_data, y_data, t_data)
-        tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=0.01)
-        op_res2 = intrp2.evaluate(func)(x_data, y_data, t_data)
-        tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01)
+    for dtype in ['float16', 'float32']:
+        n, t, d = tvm.var("n"), tvm.var("t"), 100
+        x = relay.var("x", shape=(n, t, d))
+        y = relay.var("y", shape=(n, t, d))
+        z = relay.concatenate((x, y), axis=-1)
+        assert "axis=" in z.astext()
+        zz = run_infer_type(z)
+        assert zz.checked_type == relay.TensorType((n, t, 200))
+
+        x = relay.exp(x)
+        z = relay.concatenate((x, y), axis=2)
+        zz = run_infer_type(z)
+        assert zz.checked_type == relay.TensorType((n, t, 200))
+
+        z = relay.concatenate((x, y), axis=1)
+        zz = run_infer_type(z)
+        assert zz.checked_type == relay.TensorType((n, t + t, 100))
+
+        # check shape mismatches (the following case is expected to raise tvm._ffi.base.TVMError.
+        try:
+            x = relay.var('p1', shape=(2, 5))
+            y = relay.var('p2', shape=(2, 3))
+            c = relay.concatenate([x, y], axis=0)
+            func = relay.Function([x, y], c)
+            zz = run_infer_type(func)
+        except tvm._ffi.base.TVMError:
+            pass
+        else:
+            assert False
+
+        x = relay.var("x", shape=(10, 5), dtype=dtype)
+        y = relay.var("y", shape=(10, 5), dtype=dtype)
+        t = relay.var("z", shape=(), dtype=dtype)
+        z = relay.concatenate((x, y), axis=1)
+        z = relay.add(z, t)
+        # Check result.
+        func = relay.Function([x, y, t], z)
+        x_data = np.random.rand(10, 5).astype(dtype)
+        y_data = np.random.rand(10, 5).astype(dtype)
+        t_data = np.random.uniform(size=()).astype(dtype)
+        ref_res = np.concatenate((x_data, y_data), axis=1) + t_data
+
+        for target, ctx in ctx_list():
+            if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
+                continue
+            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+            intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
+            op_res1 = intrp1.evaluate(func)(x_data, y_data, t_data)
+            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=0.01)
+            op_res2 = intrp2.evaluate(func)(x_data, y_data, t_data)
+            tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01)


 def test_dropout():
-    n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d")
-    input_ty = relay.TensorType((n, t, d), "float32")
-    x = relay.var("x", input_ty)
-    y = relay.nn.dropout(x, rate=0.75)
-    assert "rate=" in y.astext()
-    yy = run_infer_type(y)
-    assert yy.checked_type == input_ty
+    for dtype in ['float16', 'float32']:
+        n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d")
+        input_ty = relay.TensorType((n, t, d), dtype)
+        x = relay.var("x", input_ty)
+        y = relay.nn.dropout(x, rate=0.75)
+        assert "rate=" in y.astext()
+        yy = run_infer_type(y)
+        assert yy.checked_type == input_ty


 def test_batch_norm():
-    # beta and gamma ignored
-    data = relay.var("data", relay.TensorType((3, 2, 1)))
-    beta = relay.var("beta", relay.TensorType((2,)))
-    gamma = relay.var("gamma", relay.TensorType((2,)))
-    moving_mean = relay.var("moving_mean", relay.TensorType((2,)))
-    moving_var = relay.var("moving_var", relay.TensorType((2,)))
-    y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
-                            center=False, scale=False)
-    yy = run_infer_type(y.astuple())
-    assert "center=" in yy.astext()
-    assert yy.checked_type == relay.ty.TupleType(tvm.convert([
-        relay.TensorType((3, 2, 1), "float32"),
-        relay.TensorType((2,), "float32"),
-        relay.TensorType((2,), "float32")
-    ]))
-
-    beta = relay.var("beta", relay.TensorType((3,)))
-    gamma = relay.var("gamma", relay.TensorType((3,)))
-    moving_mean = relay.var("moving_mean", relay.TensorType((3,)))
-    moving_var = relay.var("moving_var", relay.TensorType((3,)))
-
-    y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
-                            axis=0, center=False, scale=False)
-    yy = run_infer_type(y.astuple())
-    assert yy.checked_type == relay.ty.TupleType(tvm.convert([
-        relay.ty.TensorType((3, 2, 1), "float32"),
-        relay.ty.TensorType((3,), "float32"),
-        relay.ty.TensorType((3,), "float32")
-    ]))
-
-    # axis=-1
-    data = relay.var("data", relay.TensorType((1, 2, 3)))
-    beta = relay.var("beta", relay.TensorType((3,)))
-    gamma = relay.var("gamma", relay.TensorType((3,)))
-    moving_mean = relay.var("moving_mean", relay.TensorType((3,)))
-    moving_var = relay.var("moving_var", relay.TensorType((3,)))
-    y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
-                            axis=-1, center=False, scale=False)
-    yy = run_infer_type(y.astuple())
-    assert yy.checked_type == relay.ty.TupleType(tvm.convert([
-        relay.ty.TensorType((1, 2, 3), "float32"),
-        relay.ty.TensorType((3,), "float32"),
-        relay.ty.TensorType((3,), "float32")
-    ]))
+    for dtype in ['float16', 'float32']:
+        # beta and gamma ignored
+        data = relay.var("data", relay.TensorType((3, 2, 1), dtype))
+        beta = relay.var("beta", relay.TensorType((2,), dtype))
+        gamma = relay.var("gamma", relay.TensorType((2,), dtype))
+        moving_mean = relay.var("moving_mean", relay.TensorType((2,), dtype))
+        moving_var = relay.var("moving_var", relay.TensorType((2,), dtype))
+        y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
+                                center=False, scale=False)
+        yy = run_infer_type(y.astuple())
+        assert "center=" in yy.astext()
+        assert yy.checked_type == relay.ty.TupleType(tvm.convert([
+            relay.TensorType((3, 2, 1), dtype),
+            relay.TensorType((2,), dtype),
+            relay.TensorType((2,), dtype)
+        ]))
+
+        beta = relay.var("beta", relay.TensorType((3,), dtype))
+        gamma = relay.var("gamma", relay.TensorType((3,), dtype))
+        moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype))
+        moving_var = relay.var("moving_var", relay.TensorType((3,), dtype))
+
+        y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
+                                axis=0, center=False, scale=False)
+        yy = run_infer_type(y.astuple())
+        assert yy.checked_type == relay.ty.TupleType(tvm.convert([
+            relay.ty.TensorType((3, 2, 1), dtype),
+            relay.ty.TensorType((3,), dtype),
+            relay.ty.TensorType((3,), dtype)
+        ]))
+
+        # axis=-1
+        data = relay.var("data", relay.TensorType((1, 2, 3), dtype))
+        beta = relay.var("beta", relay.TensorType((3,), dtype))
+        gamma = relay.var("gamma", relay.TensorType((3,), dtype))
+        moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype))
+        moving_var = relay.var("moving_var", relay.TensorType((3,), dtype))
+        y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
+                                axis=-1, center=False, scale=False)
+        yy = run_infer_type(y.astuple())
+        assert yy.checked_type == relay.ty.TupleType(tvm.convert([
+            relay.ty.TensorType((1, 2, 3), dtype),
+            relay.ty.TensorType((3,), dtype),
+            relay.ty.TensorType((3,), dtype)
+        ]))


 def test_dense():
-    n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
-    w = relay.var("w", relay.TensorType((2, w), "float32"))
-    y = relay.nn.dense(x, w, units=2)
-    assert "units=2" in y.astext()
-    yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32")
+    for dtype in ['float16', 'float32']:
+        # Dense accuracy for float16 is poor
+        if dtype == 'float16':
+            return
+        n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
+        x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
+        w = relay.var("w", relay.TensorType((2, w), dtype))
+        y = relay.nn.dense(x, w, units=2)
+        assert "units=2" in y.astext()
+        yy = run_infer_type(y)
+        assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype)

-    n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2
-    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
-    wh, ww = tvm.var("wh"), tvm.var("ww")
-    w = relay.var("w", relay.TensorType((ww, wh), "float32"))
-    y = relay.nn.dense(x, w)
-    yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, c, h, ww), "float32")
+        n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2
+        x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
+        wh, ww = tvm.var("wh"), tvm.var("ww")
+        w = relay.var("w", relay.TensorType((ww, wh), dtype))
+        y = relay.nn.dense(x, w)
+        yy = run_infer_type(y)
+        assert yy.checked_type == relay.TensorType((n, c, h, ww), dtype)

-    n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2
-    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
-    w = relay.var("w", relay.IncompleteType())
-    y = relay.nn.dense(x, w, units=2)
-    yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32")
-
-    x = relay.var("x", shape=(10, 5))
-    w = relay.var("w", shape=(2, 5))
-    z = relay.nn.dense(x, w)
-
-    # Check result.
-    func = relay.Function([x, w], z)
-    x_data = np.random.rand(10, 5).astype('float32')
-    w_data = np.random.rand(2, 5).astype('float32')
-    ref_res = np.dot(x_data, w_data.T)
-
-    for target, ctx in ctx_list():
-        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
-        intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
-        op_res1 = intrp1.evaluate(func)(x_data, w_data)
-        tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
-        op_res2 = intrp2.evaluate(func)(x_data, w_data)
-        tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
+        n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2
+        x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
+        w = relay.var("w", relay.IncompleteType())
+        y = relay.nn.dense(x, w, units=2)
+        yy = run_infer_type(y)
+        assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype)
+
+        x = relay.var("x", shape=(10, 5), dtype=dtype)
+        w = relay.var("w", shape=(2, 5), dtype=dtype)
+        z = relay.nn.dense(x, w)
+
+        # Check result.
+        func = relay.Function([x, w], z)
+        x_data = np.random.rand(10, 5).astype(dtype)
+        w_data = np.random.rand(2, 5).astype(dtype)
+        ref_res = np.dot(x_data, w_data.T)
+
+        for target, ctx in ctx_list():
+            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+            intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
+            op_res1 = intrp1.evaluate(func)(x_data, w_data)
+            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
+            op_res2 = intrp2.evaluate(func)(x_data, w_data)
+            tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)


 def test_bitserial_dense():
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index b1aa20ea07df..4a529f4a047f 100644
--- a/topi/tests/python/test_topi_transform.py
+++ b/topi/tests/python/test_topi_transform.py
@@ -19,6 +19,7 @@
 import tvm
 import topi
 import topi.testing
+from tvm.contrib.nvcc import have_fp16

 from common import get_all_backend

@@ -53,6 +54,9 @@ def check_device(device):
         if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
+        if in_dtype == "float16" and device == 'cuda' and not have_fp16(ctx.compute_version):
+            print("Skip because %s does not have fp16 support" % device)
+            return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_elemwise(B)
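Reviewer note (not part of the patch): a minimal sketch of how the new fp16 path can be exercised end to end, using the same skip guard the tests above rely on. It assumes a CUDA-enabled TVM build; the small element-wise add function, shapes, and tolerance below are hypothetical examples, while have_fp16 and the executor API come from the patch itself.

    # Sketch only: exercises the fp16 CUDA codegen path added in this patch.
    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib.nvcc import have_fp16

    dtype = 'float16'
    x = relay.var("x", shape=(10, 4), dtype=dtype)
    y = relay.var("y", shape=(10, 4), dtype=dtype)
    func = relay.Function([x, y], relay.add(x, y))

    target, ctx = 'cuda', tvm.gpu(0)
    # Skip fp16 on GPUs without native half support, as the updated tests do.
    if dtype == 'float16' and not have_fp16(ctx.compute_version):
        print("Skip: GPU does not support fp16")
    else:
        x_data = np.random.rand(10, 4).astype(dtype)
        y_data = np.random.rand(10, 4).astype(dtype)
        intrp = relay.create_executor("graph", ctx=ctx, target=target)
        op_res = intrp.evaluate(func)(x_data, y_data)
        np.testing.assert_allclose(op_res.asnumpy(), x_data + y_data, rtol=0.01)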