diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index 9ac8bb1fd084a..8f59e08c07979 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -40,6 +40,7 @@ List of operators
    topi.sigmoid
    topi.clip
    topi.cast
+   topi.reinterpret
    topi.transpose
    topi.flip
    topi.strided_slice
@@ -133,6 +134,7 @@ topi
 .. autofunction:: topi.sigmoid
 .. autofunction:: topi.clip
 .. autofunction:: topi.cast
+.. autofunction:: topi.reinterpret
 .. autofunction:: topi.transpose
 .. autofunction:: topi.flip
 .. autofunction:: topi.strided_slice
diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index dad5eb89a0535..61c9b36e1ffdf 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -114,6 +114,7 @@ This level enables additional math and transform operators.
    tvm.relay.full
    tvm.relay.full_like
    tvm.relay.cast
+   tvm.relay.reinterpret
    tvm.relay.split
    tvm.relay.arange
    tvm.relay.stack
@@ -263,6 +264,7 @@ Level 3 Definitions
 .. autofunction:: tvm.relay.full
 .. autofunction:: tvm.relay.full_like
 .. autofunction:: tvm.relay.cast
+.. autofunction:: tvm.relay.reinterpret
 .. autofunction:: tvm.relay.split
 .. autofunction:: tvm.relay.arange
 .. autofunction:: tvm.relay.stack
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 0749bbd02f1de..51e761516eed4 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -40,6 +40,7 @@
 _reg.register_schedule("repeat", schedule_broadcast)
 _reg.register_schedule("tile", schedule_broadcast)
 _reg.register_schedule("cast", schedule_injective)
+_reg.register_schedule("reinterpret", schedule_injective)
 _reg.register_schedule("strided_slice", schedule_injective)
 _reg.register_schedule("slice_like", schedule_injective)
 _reg.register_schedule("split", schedule_injective)
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 5137a9c469a41..5d8d28006ecb1 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -40,6 +40,26 @@ def cast(data, dtype):
     return _relay_make.cast(data, dtype)
 
 
+def reinterpret(data, dtype):
+    """Reinterpret the input tensor to the given data type.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    dtype : str
+        The target data type.
+
+    Returns
+    -------
+    result : relay.Expr
+        The reinterpreted result.
+    """
+    from .. import _make as _relay_make
+    return _relay_make.reinterpret(data, dtype)
+
+
 def expand_dims(data, axis, num_newaxis=1):
     """Insert `num_newaxis` axises at the position given by `axis`.
 
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 59424884ccfe8..92e4fd098258a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -97,6 +97,52 @@ RELAY_REGISTER_OP("cast")
 .set_attr<TOpPattern>("TOpPattern", kElemWise)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
 
+// relay.reinterpret
+bool ReinterpretRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "Reinterpret: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+  const auto* param = attrs.as<CastAttrs>();
+  reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
+  return true;
+}
+
+Array<Tensor> ReinterpretCompute(const Attrs& attrs, const Array<Tensor>& inputs,
+                                 const Type& out_type, const Target& target) {
+  const CastAttrs* param = attrs.as<CastAttrs>();
+  CHECK(param != nullptr);
+  DataType dtype = param->dtype;
+  return {topi::reinterpret(inputs[0], dtype)};
+}
+
+Expr MakeReinterpret(Expr data, DataType dtype) {
+  auto attrs = make_node<CastAttrs>();
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("reinterpret");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay._make.reinterpret").set_body([](const TVMArgs& args, TVMRetValue* rv) {
+  runtime::detail::unpack_call<Expr, 2>(MakeReinterpret, args, rv);
+});
+
+RELAY_REGISTER_OP("reinterpret")
+    .describe(R"code(Reinterpret the data into a new data type.
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .set_attrs_type_key("relay.attrs.CastAttrs")
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(3)
+    .add_type_rel("Reinterpret", ReinterpretRel)
+    .set_attr<FTVMCompute>("FTVMCompute", ReinterpretCompute)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
 // relay.expand_dims
 TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
 
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index da3de2b22f741..01c0a120dbcb9 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -75,6 +75,7 @@ def test_cast():
     assert "dtype=" in yy.astext()
     assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")
 
+
 def test_clip():
     a = relay.var("a", relay.TensorType((10, 4), "float32"))
     y = relay.clip(a, 1., 4.)
@@ -88,6 +89,69 @@ def test_clip():
     np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
 
 
+def test_reinterpret():
+    a = relay.var("a", relay.TensorType((1000, 4), "float32"))
+    y = relay.reinterpret(a, "int32")
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((1000, 4), "int32")
+
+    data = np.random.randn(1000, 4).astype('float32') * 1000
+    intrp = create_executor()
+    op_res = intrp.evaluate(y, {a: relay.const(data)})
+    ref_res = data.view("int32")
+    np.testing.assert_equal(op_res.asnumpy(), ref_res)
+
+
+def test_approximate_transcendental():
+    def C(x):
+        return relay.expr.const(x, "float32")
+
+    def approx_exp(x):
+        # An approximation derived from Opus,
+        # https://github.com/xiph/opus/blob/c1c247/celt/mathops.h#L147-L165
+        x = relay.minimum(relay.maximum(x, C(-88.0)), C(88.0))
+        x = C(127.0) + x * C(1.44269504)
+        xf = relay.floor(x)
+        i = relay.cast(xf, "int32")
+        x = x - xf
+        Y = C(0.99992522) + x * (C(0.69583354) + x * (C(0.22606716) + x * C(0.078024523)))
+        exponent = relay.left_shift(i, relay.expr.const(23, "int32"))
+        exponent = relay.reinterpret(exponent, "float32")
+        return exponent * Y
+
+    def approximate_sigmoid(x):
+        y = approx_exp(x)
+        return y / (y + C(1.0))
+
+    def approximate_tanh(x):
+        x = x * C(2.0)
+        y = approx_exp(x)
+        return (y - C(1.0)) / (y + C(1.0))
+
+    a = relay.var("a", relay.TensorType((1000,), "float32"))
+    y = approximate_sigmoid(a)
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((1000,), "float32")
+    data = np.linspace(-5, 5, 1000).astype("float32")
+    intrp = create_executor()
+    op_res = intrp.evaluate(y, {a: relay.const(data)})
+
+    def reference_sigmoid(x):
+        return np.exp(-np.logaddexp(0, -x))
+    np.testing.assert_allclose(op_res.asnumpy(), reference_sigmoid(data), atol=2e-5, rtol=1e-9)
+
+    y = approximate_tanh(a)
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((1000,), "float32")
+    data = np.linspace(-5, 5, 1000).astype("float32")
+    intrp = create_executor()
+    op_res = intrp.evaluate(y, {a: relay.const(data)})
+
+    def reference_tanh(x):
+        return np.tanh(x)
+    np.testing.assert_allclose(op_res.asnumpy(), reference_tanh(data), atol=4e-5, rtol=1e-9)
+
+
 def test_squeeze():
     def verify_squeeze(shape, dtype, axis):
         x = relay.var("x", relay.TensorType(shape, dtype))
diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h
index b6e6adad0715b..000567eeae140 100644
--- a/topi/include/topi/elemwise.h
+++ b/topi/include/topi/elemwise.h
@@ -269,14 +269,34 @@ inline Tensor cast(const Tensor& x,
 }
 
 /*!
-* \brief Creates an operation that sum each element of a tensor
-*
-* \param xs The input tensor array
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the sum operation
-*/
+ * \brief Reinterpret each element of x to the given type.
+ *
+ * \param x The input tensor
+ * \param type The type to cast to
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the reinterpret operation
+ */
+inline Tensor reinterpret(const Tensor& x, Type type, std::string name = "tensor",
+                          std::string tag = kElementWise) {
+  return compute(x->shape,
+                 [&](const Array<Var>& i) {
+                   return tvm::ir::Call::make(type, "reinterpret", {x(i)},
+                                              tvm::ir::Call::PureIntrinsic);
+                 },
+                 name, tag);
+}
+
+/*!
+ * \brief Creates an operation that sums each element of a tensor
+ *
+ * \param xs The input tensor array
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the sum operation
+ */
 inline Tensor elemwise_sum(const Array<Tensor>& xs,
                            std::string name = "T_elemwise_sum",
                            std::string tag = kElementWise) {
diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py
index 406d489696825..87ac06c76c751 100644
--- a/topi/python/topi/math.py
+++ b/topi/python/topi/math.py
@@ -343,3 +343,21 @@ def cast(x, dtype):
         return tvm.compute(
             x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
     return tvm.make._cast(dtype, x)
+
+def reinterpret(x, dtype):
+    """Reinterpret input to the specified data type.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        Input argument.
+
+    dtype : str
+        Data type.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
+    """
+    return cpp.reinterpret(x, dtype)
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index 44134d7c2d67f..6c5a0b438cb22 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -193,6 +193,12 @@ TVM_REGISTER_GLOBAL("topi.cast")
   *rv = cast(args[0], args[1]);
   });
 
+
+TVM_REGISTER_GLOBAL("topi.reinterpret")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = reinterpret(args[0], args[1]);
+  });
+
 TVM_REGISTER_GLOBAL("topi.elemwise_sum")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = elemwise_sum(args[0]);
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index 7f2c73e003903..4a99d5f9acaa2 100644
--- a/topi/tests/python/test_topi_transform.py
+++ b/topi/tests/python/test_topi_transform.py
@@ -45,6 +45,29 @@ def check_device(device):
         check_device(device)
 
 
+def verify_reinterpret(in_shape, in_dtype, out_dtype, generator):
+    A = tvm.placeholder(shape=in_shape, name="A", dtype=in_dtype)
+    B = topi.reinterpret(A, out_dtype)
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_elemwise(B)
+        foo = tvm.build(s, [A, B], device, name="reinterpret")
+        data_npy = generator(in_shape).astype(in_dtype)
+        out_npy = data_npy.view(B.dtype)
+        data_nd = tvm.nd.array(data_npy, ctx)
+        out_nd = tvm.nd.array(np.empty(in_shape).astype(B.dtype), ctx)
+        foo(data_nd, out_nd)
+        np.testing.assert_equal(out_nd.asnumpy(), out_npy)
+
+    for device in get_all_backend():
+        check_device(device)
+
+
 def verify_transpose(in_shape, axes):
     A = tvm.placeholder(shape=in_shape, name="A")
     B = topi.transpose(A, axes)
@@ -434,6 +457,17 @@ def test_expand_dims():
     verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
 
 
+def test_reinterpret():
+    verify_reinterpret((1000,), "float32", "int32",
+                       lambda shape: np.random.randn(*shape) * 1000)
+    verify_reinterpret((1000,), "float16", "int16",
+                       lambda shape: np.random.randn(*shape) * 100)
+    verify_reinterpret((1000,), "int16", "uint16",
+                       lambda shape: np.random.randint(-1000, 1000, size=shape))
+    verify_reinterpret((1000,), "uint32", "int32",
+                       lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape))
+
+
 def test_transpose():
     verify_transpose((3, 10, 2), (1, 0, 2))
     verify_transpose((3, 10, 5), (2, 0, 1))
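Not part of the patch: a minimal usage sketch of the new `relay.reinterpret` op, assuming a TVM checkout from the same era with this change applied (it reuses `relay.create_executor` and `NDArray.asnumpy` exactly as the tests above do; the comparison against NumPy's `view()` mirrors those tests).

```python
import numpy as np
import tvm
from tvm import relay

# Build a tiny Relay function around the new op: reinterpret returns the same
# bits with a new dtype (a bitcast), unlike relay.cast which converts values.
a = relay.var("a", relay.TensorType((4,), "float32"))
func = relay.Function([a], relay.reinterpret(a, "int32"))

data = np.array([1.0, -2.5, 0.0, 3.14], dtype="float32")
intrp = relay.create_executor()
out = intrp.evaluate(func)(data).asnumpy()

# NumPy's view() is the reference bitcast used by the tests above.
np.testing.assert_equal(out, data.view("int32"))
```

For contrast, `relay.cast` of 1.0f to int32 yields 1, while `relay.reinterpret` yields 1065353216 (0x3F800000), the IEEE-754 bit pattern of 1.0f; this is what lets the approx_exp trick in the new test assemble a float from a shifted integer exponent.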