[TOPI][Relay][TensorFlow] Add OneHot operator (apache#3781)
* Add one-hot to Relay

* topi implementation

* Working

* add topi test

* Add TF test

* Fix check

* fix linting issues

* fix documentation

* Fix documentation

* Add support for on_value, off_value, axis, dtype

* Add full support for axis

* Fix compute and update test_forward

* Move on_value and off_value to inputs

* Add topi test

* Update tests

* Update docs

* Fix style

* re-enable tests

* Add one_hot to mxnet converter
soiferj authored and wweic committed Sep 16, 2019
1 parent f375386 commit 18c1fa0
Showing 17 changed files with 473 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
@@ -104,6 +104,7 @@ List of operators
   topi.argsort
   topi.topk
   topi.sequence_mask
   topi.one_hot


List of schedules
@@ -173,6 +174,7 @@ topi
.. autofunction:: topi.argsort
.. autofunction:: topi.topk
.. autofunction:: topi.sequence_mask
.. autofunction:: topi.one_hot

topi.nn
~~~~~~~
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
@@ -200,6 +200,7 @@ This level supports backpropagation of broadcast operators. It is temporary.
   tvm.relay.nn.batch_matmul
   tvm.relay.contrib.adaptive_max_pool2d
   tvm.relay.contrib.adaptive_avg_pool2d
   tvm.relay.one_hot


**Level 11: Dialect Operators**
@@ -350,6 +351,7 @@ Level 10 Definitions
.. autofunction:: tvm.relay.nn.batch_matmul
.. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
.. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d
.. autofunction:: tvm.relay.one_hot


Level 11 Definitions
16 changes: 16 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -298,6 +298,22 @@ struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
}
};

/*! \brief Attributes used in one-hot operator */
struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
  int depth;
  int axis;
  DataType dtype;

  TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") {
    TVM_ATTR_FIELD(depth).set_default(1)
        .describe("Depth of the one hot dimension.");
    TVM_ATTR_FIELD(axis).set_default(-1)
        .describe("Axis to fill.");
    TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
        .describe("Output data type.");
  }
};  // struct OneHotAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
9 changes: 9 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
@@ -896,6 +896,14 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
        ret.append(_op.stack(inputs, axis=0))
    return ret

def _mx_one_hot(inputs, attrs):
    indices = inputs[0].astype('int32')
    depth = attrs.get_int('depth', 0)
    dtype = attrs.get_str('dtype', 'int32')
    on_value = tvm.relay.const(attrs.get_float('on_value', 1.0), dtype)
    off_value = tvm.relay.const(attrs.get_float('off_value', 0.0), dtype)
    return _op.one_hot(indices, on_value, off_value, depth, -1, dtype)


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
@@ -1041,6 +1049,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
"LinearRegressionOutput" : _mx_linear_regression_output,
"smooth_l1" : _mx_smooth_l1,
"_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
"one_hot" : _mx_one_hot,
# vision
"_contrib_BilinearResize2D" : _mx_resize,
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
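For reference, a minimal usage sketch of the new MXNet converter path (not part of the commit; it assumes mxnet and tvm are importable):

import mxnet as mx
from tvm import relay

# MXNet's one_hot carries depth/on_value/off_value/dtype as attributes;
# _mx_one_hot above casts the indices to int32 and always fills axis -1.
indices = mx.sym.var("indices")
sym = mx.sym.one_hot(indices, depth=3, on_value=1.0, off_value=0.0, dtype="float32")
mod, params = relay.frontend.from_mxnet(sym, shape={"indices": (3,)})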
16 changes: 16 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -1212,6 +1212,21 @@ def _impl(inputs, attr, params):
        return get_relay_op('log')(add_out)
    return _impl

def _one_hot():
    def _impl(inputs, attr, params):
        depth = int(_get_num_param(params, inputs[1]))
        dtype = attr['T'].name

        on_value = _get_num_param(params, inputs[2])
        off_value = _get_num_param(params, inputs[3])
        new_inputs = [inputs[0],
                      tvm.relay.const(on_value, dtype),
                      tvm.relay.const(off_value, dtype)]
        return AttrCvt('one_hot',
                       ignores=['TI'],
                       extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr)
    return _impl

# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -1284,6 +1299,7 @@ def _impl(inputs, attr, params):
    'Mul' : _elemwise('multiply'),
    'Neg' : AttrCvt('negative'),
    'NotEqual' : _broadcast('not_equal'),
    'OneHot' : _one_hot(),
    'Pack' : _pack(),
    'Pad' : _pad('Pad'),
    'PadV2' : _pad('PadV2'),
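A matching sketch for the TensorFlow path (again not part of the commit; assumes TensorFlow 1.x):

import tensorflow as tf
from tvm import relay

with tf.Graph().as_default() as graph:
    # In the TF graph, depth/on_value/off_value arrive as constant inputs
    # 1..3 of the OneHot node; _one_hot above folds them into Relay constants.
    indices = tf.placeholder(shape=(3,), dtype="int32", name="indices")
    out = tf.one_hot(indices, depth=3, on_value=1.0, off_value=0.0, axis=-1)
    graph_def = graph.as_graph_def()

mod, params = relay.frontend.from_tensorflow(graph_def, shape={"indices": (3,)})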
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
@@ -52,6 +52,7 @@
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective)
_reg.register_schedule("sequence_mask", schedule_injective)
_reg.register_schedule("one_hot", schedule_injective)


# layout_transform
44 changes: 44 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -748,3 +748,47 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):
        [[ 0.1, 0.1, 0.1], [ 16., 17., 18.]]]
    """
    return _make.sequence_mask(data, valid_length, mask_value, axis)

def one_hot(indices, on_value, off_value, depth, axis, dtype):
    """
    Returns a one-hot tensor where the locations represented by indices take value on_value,
    other locations take value off_value.
    The output shape is <indices outer dimensions> x depth x <indices inner dimensions>.

    Parameters
    ----------
    indices : relay.Expr
        Locations to set to on_value.

    on_value : relay.Expr
        Value to fill at indices.

    off_value : relay.Expr
        Value to fill at all other positions besides indices.

    depth : int
        Depth of the one-hot dimension.

    axis : int
        Axis to fill.

    dtype : str
        Data type of the output tensor.

    Returns
    -------
    ret : relay.Expr
        The one-hot tensor.

    Examples
    --------
    .. code-block:: python

        indices = [0, 1, 2]

        relay.one_hot(indices, relay.const(1), relay.const(0),
                      depth=3, axis=-1, dtype="int32") =
            [[1, 0, 0],
             [0, 1, 0],
             [0, 0, 1]]
    """
    return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
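A runnable end-to-end sketch of the signature documented above (not part of the commit):

import numpy as np
import tvm
from tvm import relay

indices = relay.var("indices", relay.TensorType((3,), "int32"))
out = relay.one_hot(indices, relay.const(1.0), relay.const(0.0),
                    depth=3, axis=-1, dtype="float32")
func = relay.Function([indices], out)

# Evaluate on CPU; res is the 3x3 one-hot matrix from the docstring example.
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
res = intrp.evaluate(func)(np.array([0, 1, 2], dtype="int32"))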
89 changes: 89 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -2482,5 +2482,94 @@ Examples::
.set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// relay.one_hot
TVM_REGISTER_NODE_TYPE(OneHotAttrs);

bool OneHotRel(const Array<Type>& types,
               int num_inputs,
               const Attrs& attrs,
               const TypeReporter& reporter) {
  // `types` contains: [indices, on_value, off_value, result]
  CHECK_EQ(types.size(), 4);
  const auto* indices = types[0].as<TensorTypeNode>();
  CHECK(indices);

  const auto* param = attrs.as<OneHotAttrs>();
  CHECK_GT(param->depth, 0);

  Array<IndexExpr> oshape;
  int ndim = indices->shape.size() + 1;
  int indices_index = 0;
  int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis;
  for (int i = 0; i < ndim; i++) {
    if (i == true_axis) {
      oshape.push_back(Integer(param->depth));
    } else {
      oshape.push_back(indices->shape[indices_index++]);
    }
  }

  reporter->Assign(types[3], TensorTypeNode::make(oshape, param->dtype));
  return true;
}

Array<Tensor> OneHotCompute(const Attrs& attrs,
                            const Array<Tensor>& inputs,
                            const Type& out_type,
                            const Target& target) {
  const auto* param = attrs.as<OneHotAttrs>();
  CHECK(param != nullptr);
  return Array<Tensor> {
    topi::one_hot(inputs[0],
                  inputs[1](),
                  inputs[2](),
                  param->depth,
                  param->axis,
                  param->dtype)
  };
}

Expr MakeOneHot(Expr indices,
                Expr on_value,
                Expr off_value,
                int depth,
                int axis,
                DataType dtype) {
  auto attrs = make_node<OneHotAttrs>();
  attrs->depth = depth;
  attrs->axis = axis;
  attrs->dtype = dtype;
  static const Op& op = Op::Get("one_hot");
  return CallNode::make(op, {indices, on_value, off_value}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.one_hot")
.set_body_typed(MakeOneHot);

RELAY_REGISTER_OP("one_hot")
.describe(R"code(Returns a one-hot tensor where the locations represented by indices take value on_value,
other locations take value off_value. The output shape is <indices outer dimensions> x depth x <indices inner dimensions>.
**indices** Locations to set to on_value.
**on_value** Value to fill at indices.
**off_value** Value to fill at all other positions besides indices.
**depth** Depth of the one-hot dimension.
**axis** Axis to fill.
**dtype** Output data type.)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.OneHotAttrs")
.set_num_inputs(3)
.add_argument("indices", "Tensor", "Locations to set to on_value.")
.add_argument("on_value", "Expr", "Value to fill at indices.")
.add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.")
.set_support_level(10)
.add_type_rel("OneHot", OneHotRel)
.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
.set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);

} // namespace relay
} // namespace tvm
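The semantics implemented by OneHotRel and OneHotCompute can be restated as a short NumPy reference (a sketch, not part of the commit; in the test suite this role is played by topi.testing.one_hot):

import numpy as np

def one_hot_ref(indices, on_value, off_value, depth, axis, dtype):
    # Output shape: the indices shape with `depth` inserted at `axis`
    # (axis == -1 appends it), mirroring the loop in OneHotRel.
    true_axis = indices.ndim if axis == -1 else axis
    oshape = list(indices.shape)
    oshape.insert(true_axis, depth)
    out = np.full(oshape, off_value, dtype=dtype)
    # Write on_value wherever the coordinate along the one-hot axis
    # equals the corresponding entry of `indices`.
    idx = np.expand_dims(indices, true_axis)
    steps = np.arange(depth).reshape(
        [-1 if i == true_axis else 1 for i in range(len(oshape))])
    np.copyto(out, on_value, where=(idx == steps))
    return out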
20 changes: 20 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
@@ -778,6 +778,25 @@ def verify(shape, axis=-1):
    verify((2, 5), axis=0)
    verify((2, 5, 6))

def test_forward_one_hot():
    def verify(indices_shape, depth, on_value, off_value, dtype):
        x = np.random.randint(0, 5, size=indices_shape)
        ref_res = mx.nd.one_hot(mx.nd.array(x), depth, on_value, off_value, dtype)
        mx_sym = mx.sym.one_hot(mx.sym.var("x"), depth, on_value, off_value, dtype)
        shape_dict = {"x": x.shape}
        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(x.astype("float32"))
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)

    verify((3,), 3, 1, 0, "int32")
    verify((3,), 3, 1.0, 0.0, "float32")
    verify((2, 2), 5, 2, -2, "int32")
    verify((2, 2), 5, 0.5, -0.5, "float32")
    verify((3, 2, 4, 5), 6, 1, 0, "int32")
    verify((3, 2, 4, 5), 6, 1.0, 0.0, "float32")

if __name__ == '__main__':
    test_forward_mlp()
    test_forward_vgg()
@@ -825,3 +844,4 @@ def verify(shape, axis=-1):
    test_forward_contrib_div_sqrt_dim()
    test_forward_batch_norm()
    test_forward_layer_norm()
    test_forward_one_hot()
19 changes: 19 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -2158,6 +2158,24 @@ def test_placeholder():
    compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0',
                        init_global_variables=True)

#######################################################################
# OneHot
# ----------------------
def _test_forward_one_hot(indices_shape, depth, on_value, off_value, axis, out_dtype):
    inp_array1 = np.random.randint(0, 5, size=indices_shape)
    with tf.Graph().as_default():
        in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype)
        out = tf.one_hot(in1, depth, on_value, off_value, axis, dtype=out_dtype)
        compare_tf_with_tvm(inp_array1, in1.name, out.name)

def test_forward_one_hot():
    _test_forward_one_hot((3,), 3, 1, 0, -1, "int32")
    _test_forward_one_hot((3,), 3, 1.0, 0.0, -1, "float32")
    _test_forward_one_hot((2, 2), 5, 2, -2, 0, "int32")
    _test_forward_one_hot((2, 2), 5, 0.5, -0.5, 1, "float32")
    _test_forward_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
    _test_forward_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")


#######################################################################
# Main
@@ -2193,6 +2211,7 @@ def test_placeholder():
    test_forward_right_shift()
    test_forward_left_shift()
    test_forward_truncatemod()
    test_forward_one_hot()

    # Activations
    test_forward_sigmoid()
40 changes: 40 additions & 0 deletions tests/python/relay/test_op_level10.py
@@ -296,6 +296,45 @@ def _verify(data_shape, mask_value, axis, dtype, itype):
    _verify((2, 3, 5, 3), 0.0, 0, 'float32', 'int64')
    _verify((5, 8, 3), 0.1, 1, 'float64', 'float32')

def test_one_hot():
    def _get_oshape(indices_shape, depth, axis):
        oshape = []
        true_axis = len(indices_shape) if axis == -1 else axis
        ndim = len(indices_shape) + 1
        indices_index = 0
        for i in range(0, ndim):
            if i == true_axis:
                oshape.append(depth)
            else:
                oshape.append(indices_shape[indices_index])
                indices_index += 1

        return oshape

    def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
        indices = relay.var("indices", relay.TensorType(indices_shape, "int32"))
        on_value_const = relay.const(on_value)
        off_value_const = relay.const(off_value)
        out = relay.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype)
        checked = run_infer_type(out)
        assert checked.checked_type == relay.ty.TensorType(_get_oshape(indices_shape, depth, axis), dtype)
        func = relay.Function([indices], out)
        indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32")
        out_np = topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype)

        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                out_relay = intrp.evaluate(func)(indices_np)
                tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)

    _verify((3,), 3, 1, 0, -1, "int32")
    _verify((3,), 3, 1.0, 0.0, -1, "float32")
    _verify((2, 2), 5, 2, -2, 0, "int32")
    _verify((2, 2), 5, 0.5, -0.5, 1, "float32")
    _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
    _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")

if __name__ == "__main__":
    test_adaptive_pool2d()
    test_collapse_sum_like()
@@ -306,4 +345,5 @@ def _verify(data_shape, mask_value, axis, dtype, itype):
    test_shape_of()
    test_sequence_mask()
    test_ndarray_size()
    test_one_hot()
