[TOPI][Relay][TensorFlow] Add OneHot operator #3781

Merged · 19 commits · Aug 22, 2019
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
@@ -104,6 +104,7 @@ List of operators
topi.argsort
topi.topk
topi.sequence_mask
topi.one_hot


List of schedules
@@ -173,6 +174,7 @@ topi
.. autofunction:: topi.argsort
.. autofunction:: topi.topk
.. autofunction:: topi.sequence_mask
.. autofunction:: topi.one_hot

topi.nn
~~~~~~~
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
@@ -200,6 +200,7 @@ This level supports backpropagation of broadcast operators. It is temporary.
tvm.relay.nn.batch_matmul
tvm.relay.contrib.adaptive_max_pool2d
tvm.relay.contrib.adaptive_avg_pool2d
tvm.relay.one_hot


**Level 11: Dialect Operators**
@@ -350,6 +351,7 @@ Level 10 Definitions
.. autofunction:: tvm.relay.nn.batch_matmul
.. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
.. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d
.. autofunction:: tvm.relay.one_hot


Level 11 Definitions
16 changes: 16 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -298,6 +298,22 @@ struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
}
};

/*! \brief Attributes used in one-hot operator */
struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
int depth;
int axis;
DataType dtype;

TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") {
TVM_ATTR_FIELD(depth).set_default(1)
.describe("Depth of the one hot dimension.");
TVM_ATTR_FIELD(axis).set_default(-1)
.describe("Axis to fill.");
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
.describe("Output data type.");
}
}; // struct OneHotAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
9 changes: 9 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
@@ -896,6 +896,14 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
ret.append(_op.stack(inputs, axis=0))
return ret

def _mx_one_hot(inputs, attrs):
indices = inputs[0].astype('int32')
depth = attrs.get_int('depth', 0)
dtype = attrs.get_str('dtype', 'int32')
on_value = tvm.relay.const(attrs.get_float('on_value', 1.0), dtype)
off_value = tvm.relay.const(attrs.get_float('off_value', 0.0), dtype)
return _op.one_hot(indices, on_value, off_value, depth, -1, dtype)


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
@@ -1041,6 +1049,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
"LinearRegressionOutput" : _mx_linear_regression_output,
"smooth_l1" : _mx_smooth_l1,
"_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
"one_hot" : _mx_one_hot,
# vision
"_contrib_BilinearResize2D" : _mx_resize,
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
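
For context, MXNet's one_hot carries no axis attribute, so `_mx_one_hot` above always fills along the last axis and wraps on_value/off_value as Relay constants. A minimal sketch of the round trip (illustrative only; the variable names and shape are assumptions, not part of this diff):

    import mxnet as mx
    from tvm import relay

    x = mx.sym.var("x")
    # MXNet defaults: on_value=1.0, off_value=0.0
    sym = mx.sym.one_hot(x, depth=3, dtype="float32")
    mod, _ = relay.frontend.from_mxnet(sym, {"x": (4,)})
    # The resulting Relay body is roughly:
    #   one_hot(cast(%x, "int32"), 1f, 0f, depth=3, axis=-1, dtype="float32")
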
16 changes: 16 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -1212,6 +1212,21 @@ def _impl(inputs, attr, params):
return get_relay_op('log')(add_out)
return _impl

def _one_hot():
def _impl(inputs, attr, params):
depth = int(_get_num_param(params, inputs[1]))
dtype = attr['T'].name

on_value = _get_num_param(params, inputs[2])
off_value = _get_num_param(params, inputs[3])
new_inputs = [inputs[0],
              tvm.relay.const(on_value, dtype),
              tvm.relay.const(off_value, dtype)]
return AttrCvt('one_hot',
ignores=['TI'],
extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr)
return _impl

# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -1284,6 +1299,7 @@ def _impl(inputs, attr, params):
'Mul' : _elemwise('multiply'),
'Neg' : AttrCvt('negative'),
'NotEqual' : _broadcast('not_equal'),
'OneHot' : _one_hot(),
'Pack' : _pack(),
'Pad' : _pad('Pad'),
'PadV2' : _pad('PadV2'),
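
For reference, a sketch of the TensorFlow-side graph this conversion path handles (TF 1.x API, matching the tests later in this PR; the placeholder shape and fill values are illustrative assumptions):

    import tensorflow as tf  # TF 1.x

    with tf.Graph().as_default():
        inp = tf.placeholder(shape=(3,), dtype=tf.int32)
        out = tf.one_hot(inp, depth=3, on_value=5.0, off_value=-1.0, axis=-1)
        # _one_hot() reads depth, on_value, and off_value from the graph's
        # constant inputs and emits roughly:
        #   relay.one_hot(%inp, const(5.0), const(-1.0),
        #                 depth=3, axis=-1, dtype="float32")
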
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
@@ -52,6 +52,7 @@
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective)
_reg.register_schedule("sequence_mask", schedule_injective)
_reg.register_schedule("one_hot", schedule_injective)


# layout_transform
44 changes: 44 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -748,3 +748,47 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):
[[ 0.1, 0.1, 0.1], [ 16., 17., 18.]]]
"""
return _make.sequence_mask(data, valid_length, mask_value, axis)

def one_hot(indices, on_value, off_value, depth, axis, dtype):
"""
Returns a one-hot tensor where the locations represented by indices take value on_value,
while all other locations take value off_value.
The output shape is <indices outer dimensions> x depth x <indices inner dimensions>.

Parameters
----------
indices : relay.Expr
Locations to set to on_value.

on_value : relay.Expr
Value to fill at indices.

off_value : relay.Expr
Value to fill at all other positions besides indices.

depth : int
Depth of the one-hot dimension.

axis : int
Axis to fill.

dtype : str
Data type of the output tensor.

Returns
-------
ret : relay.Expr
The one-hot tensor.

Examples
--------
.. code-block:: python

indices = [0, 1, 2]

relay.one_hot(indices, on_value=1, off_value=0, depth=3, axis=-1, dtype="int32") =
[[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
"""
return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
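
A minimal NumPy sketch of these semantics (for illustration only; the PR's actual reference implementation is topi.testing.one_hot, used by the tests below). It assumes all indices lie in [0, depth):

    import numpy as np

    def one_hot_ref(indices, on_value, off_value, depth, axis, dtype):
        # Build the result with the one-hot axis last, then move it into place.
        out = np.full(indices.shape + (depth,), off_value, dtype=dtype)
        np.put_along_axis(out, indices[..., np.newaxis], on_value, axis=-1)
        true_axis = indices.ndim if axis == -1 else axis
        return np.moveaxis(out, -1, true_axis)

    one_hot_ref(np.array([0, 1, 2]), 1, 0, depth=3, axis=-1, dtype="int32")
    # -> [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
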
89 changes: 89 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -2482,5 +2482,94 @@ Examples::
.set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// relay.one_hot
TVM_REGISTER_NODE_TYPE(OneHotAttrs);

bool OneHotRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [indices, on_value, off_value, result]
CHECK_EQ(types.size(), 4);
const auto* indices = types[0].as<TensorTypeNode>();
CHECK(indices);

const auto param = attrs.as<OneHotAttrs>();
CHECK_GT(param->depth, 0);

Array<IndexExpr> oshape;
int ndim = indices->shape.size() + 1;
int indices_index = 0;
int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis;
for (int i = 0; i < ndim; i++) {
if (i == true_axis) {
oshape.push_back(Integer(param->depth));
} else {
oshape.push_back(indices->shape[indices_index++]);
}
}

reporter->Assign(types[3], TensorTypeNode::make(oshape, param->dtype));
return true;
}

Array<Tensor> OneHotCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<OneHotAttrs>();
CHECK(param != nullptr);
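  // on_value and off_value arrive as 0-d tensors; calling them with no
  // indices (inputs[1](), inputs[2]()) yields the scalar Exprs that
  // topi::one_hot expects.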
return Array<Tensor> {
topi::one_hot(inputs[0],
inputs[1](),
inputs[2](),
param->depth,
param->axis,
param->dtype)
};
}

Expr MakeOneHot(Expr indices,
Expr on_value,
Expr off_value,
int depth,
int axis,
DataType dtype) {
auto attrs = make_node<OneHotAttrs>();
attrs->depth = depth;
attrs->axis = axis;
attrs->dtype = dtype;
static const Op& op = Op::Get("one_hot");
return CallNode::make(op, {indices, on_value, off_value}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.one_hot")
.set_body_typed(MakeOneHot);

RELAY_REGISTER_OP("one_hot")
.describe(R"code(Returns a one-hot tensor where the locations repsented by indices take value 1,
other locations take value 0. Final dimension is <indices dimensions> x depth.

**indices** Locations to set to on_value.

**on_value** Value to fill at indices.

**off_value** Value to fill at all other positions besides indices.

**depth** Depth of the one-hot dimension.

**axis** Axis to fill.

**dtype** Output data type.)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.OneHotAttrs")
.set_num_inputs(3)
.add_argument("indices", "Tensor", "Locations to set to on_value.")
.add_argument("on_value", "Expr", "Value to fill at indices.")
.add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.")
.set_support_level(10)
.add_type_rel("OneHot", OneHotRel)
.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
.set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);

} // namespace relay
} // namespace tvm
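
To make the shape rule in OneHotRel concrete: depth is inserted at the normalized axis position while the indices dimensions keep their order. A quick check (a sketch using the 0.6-era Module/InferType API, mirroring the run_infer_type helper in the tests):

    import tvm
    from tvm import relay

    indices = relay.var("indices", relay.TensorType((3, 4), "int32"))
    out = relay.one_hot(indices, relay.const(1.0), relay.const(0.0),
                        depth=5, axis=1, dtype="float32")
    mod = relay.Module.from_expr(out)
    mod = relay.transform.InferType()(mod)
    print(mod["main"].body.checked_type)  # TensorType([3, 5, 4], float32)
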
20 changes: 20 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
@@ -778,6 +778,25 @@ def verify(shape, axis=-1):
verify((2, 5), axis=0)
verify((2, 5, 6))

def test_forward_one_hot():
def verify(indices_shape, depth, on_value, off_value, dtype):
x = np.random.randint(0, 5, size=indices_shape)
ref_res = mx.nd.one_hot(mx.nd.array(x), depth, on_value, off_value, dtype)
mx_sym = mx.sym.one_hot(mx.sym.var("x"), depth, on_value, off_value, dtype)
shape_dict = {"x": x.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x.astype("float32"))
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
verify((3,), 3, 1, 0, "int32")
verify((3,), 3, 1.0, 0.0, "float32")
verify((2, 2), 5, 2, -2, "int32")
verify((2, 2), 5, 0.5, -0.5, "float32")
verify((3, 2, 4, 5), 6, 1, 0, "int32")
verify((3, 2, 4, 5), 6, 1.0, 0.0, "float32")

if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
@@ -825,3 +844,4 @@ def verify(shape, axis=-1):
test_forward_contrib_div_sqrt_dim()
test_forward_batch_norm()
test_forward_layer_norm()
test_forward_one_hot()
19 changes: 19 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -2158,6 +2158,24 @@ def test_placeholder():
compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0',
init_global_variables=True)

#######################################################################
# OneHot
# ----------------------
def _test_forward_one_hot(indices_shape, depth, on_value, off_value, axis, out_dtype):
inp_array1 = np.random.randint(0, 5, size=indices_shape)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype)
out = tf.one_hot(in1, depth, on_value, off_value, axis, dtype=out_dtype)
compare_tf_with_tvm(inp_array1, in1.name, out.name)

def test_forward_one_hot():
_test_forward_one_hot((3,), 3, 1, 0, -1, "int32")
_test_forward_one_hot((3,), 3, 1.0, 0.0, -1, "float32")
_test_forward_one_hot((2, 2), 5, 2, -2, 0, "int32")
_test_forward_one_hot((2, 2), 5, 0.5, -0.5, 1, "float32")
_test_forward_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
_test_forward_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")


#######################################################################
# Main
@@ -2193,6 +2211,7 @@ def test_placeholder():
test_forward_right_shift()
test_forward_left_shift()
test_forward_truncatemod()
test_forward_one_hot()

# Activations
test_forward_sigmoid()
40 changes: 40 additions & 0 deletions tests/python/relay/test_op_level10.py
@@ -296,6 +296,45 @@ def _verify(data_shape, mask_value, axis, dtype, itype):
_verify((2, 3, 5, 3), 0.0, 0, 'float32', 'int64')
_verify((5, 8, 3), 0.1, 1, 'float64', 'float32')

def test_one_hot():
def _get_oshape(indices_shape, depth, axis):
oshape = []
true_axis = len(indices_shape) if axis == -1 else axis
ndim = len(indices_shape) + 1
indices_index = 0
for i in range(0, ndim):
if i == true_axis:
oshape.append(depth)
else:
oshape.append(indices_shape[indices_index])
indices_index += 1

return oshape

def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
indices = relay.var("indices", relay.TensorType(indices_shape, "int32"))
on_value_const = relay.const(on_value)
off_value_const = relay.const(off_value)
out = relay.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype)
checked = run_infer_type(out)
assert checked.checked_type == relay.ty.TensorType(_get_oshape(indices_shape, depth, axis), dtype)
func = relay.Function([indices], out)
indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32")
out_np = topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
out_relay = intrp.evaluate(func)(indices_np)
tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)

_verify((3,), 3, 1, 0, -1, "int32")
_verify((3,), 3, 1.0, 0.0, -1, "float32")
_verify((2, 2), 5, 2, -2, 0, "int32")
_verify((2, 2), 5, 0.5, -0.5, 1, "float32")
_verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
_verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")

if __name__ == "__main__":
test_adaptive_pool2d()
test_collapse_sum_like()
@@ -306,4 +345,5 @@ def _verify(data_shape, mask_value, axis, dtype, itype):
test_shape_of()
test_sequence_mask()
test_ndarray_size()
test_one_hot()
