diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index e43fd5f7a2e7..ba1278b2484e 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -298,6 +298,18 @@ struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> { } }; +/*! \brief Attributes used in one_hot operators */ +struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> { + Integer depth; + Integer axis; + + TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") { + TVM_ATTR_FIELD(depth).set_default(NullValue<Integer>()) + .describe("Defining the depth of the one hot dimension."); + TVM_ATTR_FIELD(axis).set_default(-1) + .describe("The axis at which the one hot dimension is inserted (-1 means last)."); + } +}; // struct OneHotAttrs } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index bbc0fec67bf6..a93cb0adf6ee 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -894,6 +894,22 @@ def _transform_mask(stride_dim, ellipsis_mask): return _op.reshape(out, newshape=tuple(final_output)) return _impl +def _one_hot(): + def _impl(inputs, attr, params): + depth = _get_num_param(params, inputs.pop(1)) + on_value = _get_num_param(params, inputs.pop(1)) + off_value = _get_num_param(params, inputs.pop(1)) + # NOTE(review): the .dtype reads below assume _get_num_param returns numpy scalars, not plain Python numbers - confirm. + inputs.append(tvm.relay.const(on_value, dtype=on_value.dtype)) + inputs.append(tvm.relay.const(off_value, dtype=off_value.dtype)) + axis = int(attr["axis"]) + new_input = inputs[0:3] + return AttrCvt(op_name="one_hot", + extras={'depth': tvm.const(depth, 'int32'), + 'axis': tvm.const(axis, 'int32')}, + ignores=['TI'])(new_input, attr) + return _impl + def _pad(name): def _impl(inputs, attr, params): padlist = _get_param(params, inputs[1]) @@ -1284,6 +1300,7 @@ def _impl(inputs, attr, params): 'Mul' : _elemwise('multiply'), 'Neg' : AttrCvt('negative'), 'NotEqual' : _broadcast('not_equal'), + 'OneHot' : _one_hot(),
'Pack' : _pack(), 'Pad' : _pad('Pad'), 'PadV2' : _pad('PadV2'), diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 51e761516eed..922ef800cdf7 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -52,7 +52,7 @@ _reg.register_schedule("_contrib_reverse_reshape", schedule_injective) _reg.register_schedule("gather_nd", schedule_injective) _reg.register_schedule("sequence_mask", schedule_injective) - +_reg.register_schedule("one_hot", schedule_injective) # layout_transform _reg.register_schedule("layout_transform", schedule_injective) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 48d3d2032f80..b4ee872cc3d5 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -264,3 +264,8 @@ class MaxPool2DAttrs(Attrs): @register_relay_attr_node class AvgPool2DAttrs(Attrs): """Attributes used in avg_pool2d operators""" + + +@register_relay_attr_node +class OneHotAttrs(Attrs): + """Attributes used in one_hot operators""" diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 5d8d28006ecb..f38f7c15e863 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -340,6 +340,51 @@ def arange(start, stop=None, step=None, dtype="float32"): return _make.arange(start, stop, step, dtype) +def one_hot(data, on_value=None, off_value=None, depth=None, axis=None): + """Returns a one-hot tensor. + + This operator takes an index tensor with one or more dimensions and inserts a new + dimension of size depth, e.g. by default an input of shape (n,) becomes (n, depth). + + Parameters + ---------- + data: tvm.relay.Expr + The input tensor of indices. + + on_value: tvm.relay.Expr + The input data defining the value to fill in output when indices[j] = i. + + off_value: tvm.relay.Expr + The input data defining the value to fill in output when indices[j] != i. + + depth: int + A scalar defining the depth of the one hot dimension.
+ + axis: int, optional + The axis along which to add depth shape. The default axis is -1. + + Returns + ------- + result : tvm.relay.Expr + The resulting one-hot tensor. + + Examples + -------- + .. code-block:: python + + indices = [0, 2, -1, 1] + depth = 3 + relay.one_hot(indices, + on_value=5.0, off_value=0.0, + depth=depth, axis=-1) # output: [4 x 3] + # [[5.0, 0.0, 0.0], # one_hot(0) + # [0.0, 0.0, 5.0], # one_hot(2) + # [0.0, 0.0, 0.0], # one_hot(-1) + # [0.0, 5.0, 0.0]] # one_hot(1) + """ + + return _make.one_hot(data, depth, on_value, off_value, axis) + def repeat(data, repeats, axis): """Repeats elements of an array. By default, repeat flattens the input array into 1-D and then repeats the elements. diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 03a92b35d396..b3f8a0ea4a2f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2482,5 +2482,110 @@ Examples:: .set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute) .set_attr<TOpPattern>("TOpPattern", kInjective); +// one_hot operator +TVM_REGISTER_NODE_TYPE(OneHotAttrs); + +bool OneHotRel(const Array<Type>& types, + int num_inputs, + const Attrs& raw_attrs, + const TypeReporter& reporter) { + // `types` contains: [data, on_value, off_value, result] + CHECK_EQ(types.size(), 4); + const auto attrs = raw_attrs.as<OneHotAttrs>(); + CHECK(attrs != nullptr); + + const auto* data = types[0].as<TensorTypeNode>(); + if (data == nullptr) { + CHECK(types[0].as<IncompleteTypeNode>()) + << "one_hot: expect input data type to be TensorType but get " + << types[0]; + return false; + } + + CHECK(attrs->depth.defined()) << "one_hot: depth must be specified"; + int depth = static_cast<int>(attrs->depth->value); + CHECK_GT(depth, 0) + << "Invalid one_hot attributes (depth): " << attrs->depth; + const auto* on_value = types[1].as<TensorTypeNode>(); + const auto* off_value = types[2].as<TensorTypeNode>(); + if (on_value == nullptr || off_value == nullptr) { + return false; + } + + CHECK_EQ(on_value->shape.size(), 0) << "on_value should be a scalar"; + CHECK_EQ(off_value->shape.size(), 0) << "off_value should be a scalar"; + + int axis; + if (!attrs->axis.defined()) { + axis =
static_cast<int>(data->shape.size()); + } else { + axis = static_cast<int>(attrs->axis->value); + CHECK_GE(axis, -1) + << "axis should be greater equal than -1."; + CHECK_LT(axis, static_cast<int>(data->shape.size())) + << "axis should be within the input dimension range."; + if (axis < 0) { + axis = static_cast<int>(data->shape.size()); + } + } + + std::vector<IndexExpr> oshape; + const auto ndim_data = static_cast<int>(data->shape.size()); + + oshape.reserve(ndim_data + 1); + for (int i = 0; i < axis; ++i) { + oshape.emplace_back(data->shape[i]); + } + + // Depth goes at position `axis`; the remaining input dimensions follow + // (there are none when axis == ndim_data). + oshape.emplace_back(depth); + for (int i = axis; i < ndim_data; ++i) { + oshape.emplace_back(data->shape[i]); + } + + reporter->Assign(types[3], TensorTypeNode::make(Array<IndexExpr>(oshape), + on_value->dtype)); + return true; +} + +Array<Tensor> OneHotCompute(const Attrs& attrs, + const Array<Tensor>& inputs, + const Type& out_type, + const Target& target) { + const auto* param = attrs.as<OneHotAttrs>(); + CHECK(param != nullptr); + // An undefined axis means "append the depth axis last", which topi spells as -1. + const int axis = param->axis.defined() ? static_cast<int>(param->axis->value) : -1; + + return Array<Tensor>{ topi::one_hot(inputs[0], param->depth, inputs[1], inputs[2], axis) }; +} + +Expr MakeOneHot(Expr data, + Integer depth, + Expr on_value, + Expr off_value, + Integer axis) { + auto attrs = make_node<OneHotAttrs>(); + attrs->depth = std::move(depth); + attrs->axis = std::move(axis); + static const Op& op = Op::Get("one_hot"); + return CallNode::make(op, {data, on_value, off_value}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make.one_hot") +.set_body_typed(MakeOneHot); + +RELAY_REGISTER_OP("one_hot") +.describe(R"code(Returns one-hot array within a given interval-depth.
+ +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.OneHotAttrs") +.set_num_inputs(3) +.set_support_level(3) +.add_type_rel("OneHot", OneHotRel) +.set_attr<TOpPattern>("TOpPattern", kInjective) +.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute) +.set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index f1d91a255fbb..d42ed68e2370 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -685,6 +685,49 @@ def verify_gather_nd(xshape, yshape, y_data): verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]]) verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]]) +def test_one_hot_infer_type(): + def verify_one_hot(dshape, depth, oshape, axis=None): + input = relay.var("input", relay.TensorType(dshape, "int32")) + on_value = relay.var("on_value", relay.scalar_type("float32")) + off_value = relay.var("off_value", relay.scalar_type("float32")) + y = relay.one_hot(input, on_value=on_value, off_value=off_value, depth=depth, axis=axis) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType(oshape, "float32") + + d1, d2 = tvm.var("d1"), tvm.var("d2") + + verify_one_hot((d1,), 2, (d1, 2), -1) + verify_one_hot((4,), 3, (4, 3)) + verify_one_hot((3, 3), 4, (3, 3, 4)) + verify_one_hot((d1, d2), 5, (d1, d2, 5), -1) + +def test_one_hot(): + def verify_one_hot(src_shape, depth, on_value_data, off_value_data, axis=None): + data_dtype = "int32" + value_dtype = "float32" + shape_size = 1 + for i in range(len(src_shape)): + shape_size = shape_size * src_shape[i] + input_data = np.arange(shape_size, dtype=data_dtype).reshape((src_shape)) + input = relay.var("input", relay.TensorType(input_data.shape, data_dtype)) + on_value = relay.var("on_value", relay.scalar_type(value_dtype)) + off_value = relay.var("off_value", relay.scalar_type(value_dtype)) + z = relay.one_hot(input,
on_value=on_value, off_value=off_value, depth=depth, axis=axis) + + func = relay.Function([input, on_value, off_value], z) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + on_value_npy = np.array(on_value_data, dtype=value_dtype) + off_value_npy = np.array(off_value_data, dtype=value_dtype) + op_res = intrp.evaluate(func)(input_data, on_value_npy, off_value_npy) + ref_res = on_value_data * np.eye(depth)[input_data] + ref_res[ref_res == 0] = off_value_data + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + + verify_one_hot((4,), 4, 1.0, 0.0, -1) + if __name__ == "__main__": test_arange() test_cast() @@ -715,3 +758,5 @@ def verify_gather_nd(xshape, yshape, y_data): test_tile() test_repeat() test_gather_nd() + test_one_hot_infer_type() + test_one_hot() diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index e8a65b05a42c..dee6cb492bed 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -660,6 +660,62 @@ inline Tensor take(const Tensor& a, } +/*! +* \brief Build a one-hot tensor of the given depth from an index tensor. +* +* \param a The input tensor of indices. +* \param depth The depth of the one hot dimension to expand. +* \param on_value The value to fill in output when indices[j] = i. +* \param off_value The value to fill in output when indices[j] != i. +* \param axis The axis at which the depth dimension is inserted; +* -1 appends it as the last axis. +* \param name The name of the operation. +* \param tag The tag to mark the operation.
+* +* \return A Tensor whose op member is the one_hot operation +*/ +inline Tensor one_hot(const Tensor& a, + int depth, + const Tensor& on_value, + const Tensor& off_value, + int axis, + std::string name = "T_one_hot", + std::string tag = kInjective) { + int input_shape = static_cast<int>(a->shape.size()); + CHECK_GE(axis, -1) << "axis out of bounds, must >= -1, got " << axis; + CHECK_LT(axis, static_cast<int>(a->shape.size())) + << "axis out of bounds, must < a->shape.size()" + << "(" << axis << "," << a->shape.size() << ")"; + + Array<Expr> out_shape; + for (int i = 0; i < input_shape; ++i) { + // The depth dimension is inserted immediately before + // input dimension `axis`. + if (axis == i) { + out_shape.push_back(depth); + } + out_shape.push_back(a->shape[i]); + } + if (axis < 0) { + out_shape.push_back(depth); + axis = input_shape; + } + + return compute( + out_shape, [&](const Array<Var>& indices) { + Array<Expr> real_indices; + for (int j = 0; j < axis; ++j) { + real_indices.push_back(indices[j]); + } + // Skip output index `axis` (the depth coordinate); this loop is + // empty when depth was appended as the last dimension. + for (int j = axis + 1; j < static_cast<int>(indices.size()); ++j) { + real_indices.push_back(indices[j]); + } + Expr ret = tvm::ir::Select::make(a(real_indices) == indices[axis], on_value(), off_value()); + return ret; + }, name, tag); +} /*! * \brief Mask the out-of-boundary elements of each sequence. * diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 5e87933c2806..1931b16090d1 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -518,3 +518,30 @@ def where(condition, x, y): A Tensor selected from x or y depending on condition. """ return cpp.where(condition, x, y) + +def one_hot(data, depth, on_value, off_value, axis=-1): + """Creates a one_hot tensor from an input tensor along the axis. + + Parameters + ---------- + data : tvm.Tensor + n-D input, can be any layout. + + depth : int + A scalar defining the depth of the one hot dimension.
+ + on_value : tvm.Tensor + The input data defining the value to fill in output when indices[j] = i. + + off_value : tvm.Tensor + The input data defining the value to fill in output when indices[j] != i. + + axis : int, optional + The axis along which to add depth shape. The default axis is -1. + + Returns + ------- + result : tvm.Tensor + The resulting one-hot tensor. + """ + return cpp.one_hot(data, depth, on_value, off_value, axis) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 799b660df3b8..c0f21f91d5f1 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -157,7 +157,6 @@ TVM_REGISTER_GLOBAL("topi.sin") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = sin(args[0]); }); - TVM_REGISTER_GLOBAL("topi.tanh") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = tanh(args[0]); @@ -358,6 +357,13 @@ TVM_REGISTER_GLOBAL("topi.take") } }); +TVM_REGISTER_GLOBAL("topi.one_hot") +.set_body([](TVMArgs args, TVMRetValue *rv) { + int depth = args[1]; + int axis = args[4]; + *rv = one_hot(args[0], depth, args[2], args[3], axis); +}); + TVM_REGISTER_GLOBAL("topi.sequence_mask") .set_body([](TVMArgs args, TVMRetValue *rv) { double pad_val = args[2]; diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 64305b4a52cc..078192672ab0 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -312,6 +312,45 @@ def check_device(device): for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]: check_device(device) + +def verify_one_hot(src_shape, depth, on_value_data, off_value_data, axis=-1): + src_dtype = "int32" + value_dtype = "float32" + A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="indices") + on_value = tvm.placeholder(shape=(), dtype=value_dtype, name="on_value") + off_value = tvm.placeholder(shape=(), dtype=value_dtype, name="off_value") + out_tensor = topi.one_hot(A, depth, on_value, off_value, axis=axis) + + def check_device(device): + ctx =
tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(out_tensor) + + foo = tvm.build(s, [A, on_value, off_value, out_tensor], device, name="one_hot") + shape_size = 1 + for i in range(len(src_shape)): + shape_size = shape_size * src_shape[i] + data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape)) + on_value_npy = np.array(on_value_data, dtype=value_dtype) + off_value_npy = np.array(off_value_data, dtype=value_dtype) + out_npys = on_value_data * np.eye(depth)[data_npy] # reference puts the depth axis last, so it only models axis=-1 + out_npys[out_npys == 0] = off_value_data + + data_nd = tvm.nd.array(data_npy, ctx) + on_value_nd = tvm.nd.array(on_value_npy, ctx) + off_value_nd = tvm.nd.array(off_value_npy, ctx) + out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=value_dtype) + foo(data_nd, on_value_nd, off_value_nd, out_nd) + tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys) + + for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]: + check_device(device) + + def verify_strided_slice(in_shape, begin, end, strides=None): A = tvm.placeholder(shape=in_shape, name="A") strides = [1,1,1] if strides is None else strides @@ -596,6 +635,15 @@ def test_take(): verify_take((3,4), [0, 2], axis=0, mode="fast") verify_take((3,4), [0, 2], axis=1, mode="fast") +def test_one_hot(): # every case below uses axis=-1; other axes are not exercised here + verify_one_hot((3,), 3, 1.0, 0.0, axis=-1) + verify_one_hot((4,), 4, 1.0, 0.0, axis=-1) + verify_one_hot((4,), 4, 2.0, 1.0, axis=-1) + verify_one_hot((4,), 4, 5.0, 2.0, axis=-1) + verify_one_hot((2,2), 6, 5.0, 2.0, axis=-1) + verify_one_hot((2,2), 6, 5.0, -2.0, axis=-1) + verify_one_hot((5,), 10, 5.0, -1.0, axis=-1) + def test_gather_nd(): for indices_dtype in ['int32', 'float32']: verify_gather_nd((4,), [[1.8]], indices_dtype) @@ -793,3 +841,4 @@ def check_device(device): test_sequence_mask() test_ndarray_size() test_where_fusion() + test_one_hot()