Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI][Relay][TensorFlow] Add OneHot operator #3781

Merged
merged 19 commits into from
Aug 22, 2019
Merged
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ List of operators
topi.argsort
topi.topk
topi.sequence_mask
topi.one_hot


List of schedules
Expand Down Expand Up @@ -173,6 +174,7 @@ topi
.. autofunction:: topi.argsort
.. autofunction:: topi.topk
.. autofunction:: topi.sequence_mask
.. autofunction:: topi.one_hot

topi.nn
~~~~~~~
Expand Down
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.nn.batch_matmul
tvm.relay.contrib.adaptive_max_pool2d
tvm.relay.contrib.adaptive_avg_pool2d
tvm.relay.one_hot


**Level 11: Dialect Operators**
Expand Down Expand Up @@ -350,6 +351,7 @@ Level 10 Definitions
.. autofunction:: tvm.relay.nn.batch_matmul
.. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
.. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d
.. autofunction:: tvm.relay.one_hot


Level 11 Definitions
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,22 @@ struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
}
};

/*! \brief Attributes used in one-hot operator */
struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
int depth;
int axis;
DataType dtype;

TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") {
TVM_ATTR_FIELD(depth).set_default(1)
.describe("Depth of the one hot dimension.");
TVM_ATTR_FIELD(axis).set_default(-1)
.describe("Axis to fill.");
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
.describe("Output data type.");
}
}; // struct OneHotAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
16 changes: 16 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,21 @@ def _impl(inputs, attr, params):
return get_relay_op('log')(add_out)
return _impl

def _one_hot():
def _impl(inputs, attr, params):
depth = int(_get_num_param(params, inputs[1]))
dtype = attr['T'].name

on_value = _get_num_param(params, inputs[2])
off_value = _get_num_param(params, inputs[3])
new_inputs = [inputs[0], \
tvm.relay.const(on_value, dtype), \
tvm.relay.const(off_value, dtype)]
return AttrCvt('one_hot',
ignores=['TI'],
extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr)
return _impl

# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -1284,6 +1299,7 @@ def _impl(inputs, attr, params):
'Mul' : _elemwise('multiply'),
'Neg' : AttrCvt('negative'),
'NotEqual' : _broadcast('not_equal'),
'OneHot' : _one_hot(),
'Pack' : _pack(),
'Pad' : _pad('Pad'),
'PadV2' : _pad('PadV2'),
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective)
_reg.register_schedule("sequence_mask", schedule_injective)
_reg.register_schedule("one_hot", schedule_injective)


# layout_transform
Expand Down
44 changes: 44 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,3 +748,47 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):
[[ 0.1, 0.1, 0.1], [ 16., 17., 18.]]]
"""
return _make.sequence_mask(data, valid_length, mask_value, axis)

def one_hot(indices, on_value, off_value, depth, axis, dtype):
"""
Returns a one-hot tensor where the locations repsented by indices take value on_value,
other locations take value off_value.
Final dimension is <indices outer dimensions> x depth x <indices inner dimensions>.

Parameters
----------
indices : relay.Expr
Locations to set to on_value.

on_value : relay.Expr
Value to fill at indices.

off_value : relay.Expr
Value to fill at all other positions besides indices.

depth : int
Depth of the one-hot dimension.

axis : int
Axis to fill.

dtype : str
Data type of the output tensor.

Returns
-------
ret : relay.Expr
The one-hot tensor.

Examples
--------
.. code-block:: python

indices = [0, 1, 2]

relay.one_hot(indices, 3) =
[[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
"""
return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
89 changes: 89 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2482,5 +2482,94 @@ Examples::
.set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// relay.one_hot
TVM_REGISTER_NODE_TYPE(OneHotAttrs);

bool OneHotRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [indices, on_value, off_value, result]
CHECK_EQ(types.size(), 4);
const auto* indices = types[0].as<TensorTypeNode>();
CHECK(indices);

const auto param = attrs.as<OneHotAttrs>();
CHECK_GT(param->depth, 0);

Array<IndexExpr> oshape;
int ndim = indices->shape.size() + 1;
int indices_index = 0;
int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis;
for (int i = 0; i < ndim; i++) {
if (i == true_axis) {
oshape.push_back(Integer(param->depth));
} else {
oshape.push_back(indices->shape[indices_index++]);
}
}

reporter->Assign(types[3], TensorTypeNode::make(oshape, param->dtype));
return true;
}

Array<Tensor> OneHotCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<OneHotAttrs>();
CHECK(param != nullptr);
return Array<Tensor> {
topi::one_hot(inputs[0],
inputs[1](),
inputs[2](),
param->depth,
param->axis,
param->dtype)
};
}

Expr MakeOneHot(Expr indices,
Expr on_value,
Expr off_value,
int depth,
int axis,
DataType dtype) {
auto attrs = make_node<OneHotAttrs>();
attrs->depth = std::move(depth);
attrs->axis = axis;
attrs->dtype = dtype;
static const Op& op = Op::Get("one_hot");
return CallNode::make(op, {indices, on_value, off_value}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.one_hot")
.set_body_typed(MakeOneHot);

RELAY_REGISTER_OP("one_hot")
.describe(R"code(Returns a one-hot tensor where the locations repsented by indices take value 1,
other locations take value 0. Final dimension is <indices dimensions> x depth.

**indices** Locations to set to 1.

**on_value** Value to fill at indices.

**off_value** Value to fill at all other positions besides indices.

**depth** Depth of the one-hot dimension.

**axis** Axis to fill.

**dtype**)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.OneHotAttrs")
.set_num_inputs(3)
.add_argument("indices", "Tensor", "Locations to set to on_value.")
.add_argument("on_value", "Expr", "Value to fill at indices.")
.add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.")
.set_support_level(10)
.add_type_rel("OneHot", OneHotRel)
.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
.set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);

} // namespace relay
} // namespace tvm
19 changes: 19 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2158,6 +2158,24 @@ def test_placeholder():
compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0',
init_global_variables=True)

#######################################################################
# OneHot
# ----------------------
def _test_forward_one_hot(indices_shape, depth, on_value, off_value, axis, out_dtype):
inp_array1 = np.random.randint(0, 5, size=indices_shape)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype)
out = tf.one_hot(in1, depth, on_value, off_value, axis, dtype=out_dtype)
compare_tf_with_tvm(inp_array1, in1.name, out.name)

def test_forward_one_hot():
_test_forward_one_hot((3,), 3, 1, 0, -1, "int32")
_test_forward_one_hot((3,), 3, 1.0, 0.0, -1, "float32")
_test_forward_one_hot((2, 2), 5, 2, -2, 0, "int32")
_test_forward_one_hot((2, 2), 5, 0.5, -0.5, 1, "float32")
_test_forward_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
_test_forward_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")


#######################################################################
# Main
Expand Down Expand Up @@ -2193,6 +2211,7 @@ def test_placeholder():
test_forward_right_shift()
test_forward_left_shift()
test_forward_truncatemod()
test_forward_one_hot()

# Activations
test_forward_sigmoid()
Expand Down
40 changes: 40 additions & 0 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,45 @@ def _verify(data_shape, mask_value, axis, dtype, itype):
_verify((2, 3, 5, 3), 0.0, 0, 'float32', 'int64')
_verify((5, 8, 3), 0.1, 1, 'float64', 'float32')

def test_one_hot():
def _get_oshape(indices_shape, depth, axis):
oshape = []
true_axis = len(indices_shape) if axis == -1 else axis
ndim = len(indices_shape) + 1
indices_index = 0
for i in range(0, ndim):
if i == true_axis:
oshape.append(depth)
else:
oshape.append(indices_shape[indices_index])
indices_index += 1

return oshape

def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
indices = relay.var("indices", relay.TensorType(indices_shape, "int32"))
on_value_const = relay.const(on_value)
off_value_const = relay.const(off_value)
out = relay.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype)
checked = run_infer_type(out)
assert checked.checked_type == relay.ty.TensorType(_get_oshape(indices_shape, depth, axis), dtype)
func = relay.Function([indices], out)
indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32")
out_np = topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
out_relay = intrp.evaluate(func)(indices_np)
tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)

_verify((3,), 3, 1, 0, -1, "int32")
_verify((3,), 3, 1.0, 0.0, -1, "float32")
_verify((2, 2), 5, 2, -2, 0, "int32")
_verify((2, 2), 5, 0.5, -0.5, 1, "float32")
_verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
_verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")

if __name__ == "__main__":
test_adaptive_pool2d()
test_collapse_sum_like()
Expand All @@ -306,4 +345,5 @@ def _verify(data_shape, mask_value, axis, dtype, itype):
test_shape_of()
test_sequence_mask()
test_ndarray_size()
test_one_hot()

50 changes: 50 additions & 0 deletions topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1247,5 +1247,55 @@ inline Tensor ndarray_size(const Tensor& src,
}, name, tag);
}

/*!
* \brief Returns a one-hot tensor where the locations repsented by indices take value on_value,
other locations take value off_value.
* \param indices locations to set to on_value.
* \param on_value value that locations represented by indices take on.
* \param off_value value that other locations take on.
* \param depth depth of the one-hot dimension.
* \param axis axis to fill.
* \param dtype data type of the output tensor.
* \param name output tensor name.
* \param tag output tensor tag.
* \return one-hot tensor.
*/
inline Tensor one_hot(const Tensor& indices,
const Expr on_value,
const Expr off_value,
int depth,
int axis,
const Type& dtype,
const std::string name = "T_one_hot",
const std::string tag = kInjective) {
Array<Expr> oshape;
int ndim = indices->shape.size() + 1;
int indices_index = 0;
int true_axis = (axis == -1) ? indices->shape.size() : axis;
for (int i = 0; i < ndim; i++) {
if (i == true_axis) {
oshape.push_back(Integer(depth));
} else {
oshape.push_back(indices->shape[indices_index++]);
}
}

Expr on_value_cast = cast(dtype, on_value);
Expr off_value_cast = cast(dtype, off_value);
return compute(oshape, [&](const Array<Var>& iter_vars) {
Array<Var> indices_indices;
for (size_t i = 0; i < iter_vars.size(); i++) {
if (static_cast<int>(i) == true_axis) {
continue;
}

indices_indices.push_back(iter_vars[i]);
}

auto idx = iter_vars[true_axis];
return ir::Select::make(indices(indices_indices) == idx, on_value_cast, off_value_cast);
}, name, tag);
}

} // namespace topi
#endif // TOPI_TRANSFORM_H_
1 change: 1 addition & 0 deletions topi/python/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask
from .pool_grad_python import pool_grad_nchw
from .one_hot import one_hot
Loading