[Relay, TOPI] add onehot op support
honghua.cao committed Aug 21, 2019
1 parent ebda258 commit be24810
Showing 11 changed files with 366 additions and 2 deletions.
12 changes: 12 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -298,6 +298,18 @@ struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
}
};

/*! \brief Attributes used in one_hot operators */
struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
Integer depth;
Integer axis;

TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") {
  TVM_ATTR_FIELD(depth).set_default(NullValue<Integer>())
      .describe("The depth of the one-hot dimension.");
  TVM_ATTR_FIELD(axis).set_default(-1)
      .describe("The axis at which to insert the one-hot dimension; "
                "-1 appends a new innermost axis.");
}
}; // struct OneHotAttrs
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
16 changes: 16 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -894,6 +894,21 @@ def _transform_mask(stride_dim, ellipsis_mask):
return _op.reshape(out, newshape=tuple(final_output))
return _impl

def _one_hot():
    def _impl(inputs, attr, params):
        # TF's OneHot carries four inputs: indices, depth, on_value, off_value.
        # Pop the three scalar operands out of the input list.
        depth = _get_num_param(params, inputs.pop(1))
        on_value = _get_num_param(params, inputs.pop(1))
        off_value = _get_num_param(params, inputs.pop(1))
        # Re-append on/off values as Relay constants: the Relay op takes them
        # as tensor inputs, while depth and axis become attributes.
        inputs.append(tvm.relay.const(on_value, dtype=on_value.dtype))
        inputs.append(tvm.relay.const(off_value, dtype=off_value.dtype))
        axis = int(attr["axis"])
        new_input = inputs[0:3]  # [indices, on_value, off_value]
        return AttrCvt(op_name="one_hot",
                       extras={'depth': tvm.const(depth, 'int32'),
                               'axis': tvm.const(axis, 'int32')},
                       ignores=['TI'])(new_input, attr)  # 'TI': index dtype attr
    return _impl
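
For context, a minimal sketch of how a TensorFlow `OneHot` node reaches this converter. The TF graph construction and the `from_tensorflow` call are illustrative assumptions (TF 1.x API of this era), not part of this diff:

    import tensorflow as tf
    import tvm.relay as relay

    with tf.Graph().as_default() as graph:
        indices = tf.placeholder(tf.int32, shape=(4,), name="indices")
        # Emits an 'OneHot' node whose depth/on_value/off_value are inputs.
        tf.one_hot(indices, depth=3, on_value=5.0, off_value=0.0,
                   axis=-1, name="onehot")
        graph_def = graph.as_graph_def()

    mod, params = relay.frontend.from_tensorflow(
        graph_def, shape={"indices": (4,)}, outputs=["onehot"])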

def _pad(name):
def _impl(inputs, attr, params):
padlist = _get_param(params, inputs[1])
@@ -1284,6 +1299,7 @@ def _impl(inputs, attr, params):
'Mul' : _elemwise('multiply'),
'Neg' : AttrCvt('negative'),
'NotEqual' : _broadcast('not_equal'),
'OneHot' : _one_hot(),
'Pack' : _pack(),
'Pad' : _pad('Pad'),
'PadV2' : _pad('PadV2'),
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
@@ -52,7 +52,7 @@
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective)
_reg.register_schedule("sequence_mask", schedule_injective)

_reg.register_schedule("one_hot", schedule_injective)

# layout_transform
_reg.register_schedule("layout_transform", schedule_injective)
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -264,3 +264,8 @@ class MaxPool2DAttrs(Attrs):
@register_relay_attr_node
class AvgPool2DAttrs(Attrs):
"""Attributes used in avg_pool2d operators"""


@register_relay_attr_node
class OneHotAttrs(Attrs):
"""Attributes used in one_hot operators"""
45 changes: 45 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -340,6 +340,51 @@ def arange(start, stop=None, step=None, dtype="float32"):
return _make.arange(start, stop, step, dtype)


def one_hot(data, on_value=None, off_value=None, depth=None, axis=None):
    """Returns a one-hot tensor.

    This operator takes an n-dimensional index tensor and inserts a new
    dimension of size `depth`, setting the positions selected by the indices
    to `on_value` and all other positions to `off_value`.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input indices to the operator.

    on_value : tvm.relay.Expr
        Scalar defining the value to fill in output when indices[j] = i.

    off_value : tvm.relay.Expr
        Scalar defining the value to fill in output when indices[j] != i.

    depth : int
        A scalar defining the depth of the one-hot dimension.

    axis : int, optional
        The axis at which to insert the depth dimension. The default is -1,
        which appends a new innermost axis.

    Returns
    -------
    result : tvm.relay.Expr
        The resulting one-hot tensor.

    Examples
    --------
    .. code-block:: python

        indices = [0, 2, -1, 1]
        depth = 3
        relay.one_hot(indices, depth=depth,
                      on_value=5.0, off_value=0.0,
                      axis=-1)  # output shape: [4, 3]
        # [[5.0, 0.0, 0.0],  # one_hot(0)
        #  [0.0, 0.0, 5.0],  # one_hot(2)
        #  [0.0, 0.0, 0.0],  # one_hot(-1)
        #  [0.0, 5.0, 0.0]]  # one_hot(1)
    """
    return _make.one_hot(data, depth, on_value, off_value, axis)
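
A minimal usage sketch for the API above, assuming the op registration added in this commit (executor usage mirrors the tests in this diff):

    import numpy as np
    import tvm
    from tvm import relay

    indices = relay.var("indices", relay.TensorType((4,), "int32"))
    out = relay.one_hot(indices,
                        on_value=relay.const(5.0, "float32"),
                        off_value=relay.const(0.0, "float32"),
                        depth=3, axis=-1)
    func = relay.Function([indices], out)
    intrp = relay.create_executor("debug", ctx=tvm.cpu(0), target="llvm")
    print(intrp.evaluate(func)(np.array([0, 2, 1, 1], dtype="int32")))
    # [[5. 0. 0.], [0. 0. 5.], [0. 5. 0.], [0. 5. 0.]]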

def repeat(data, repeats, axis):
"""Repeats elements of an array.
By default, repeat flattens the input array into 1-D and then repeats the elements.
103 changes: 103 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -2482,5 +2482,108 @@ Examples::
.set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// one_hot operator
TVM_REGISTER_NODE_TYPE(OneHotAttrs);

bool OneHotRel(const Array<Type>& types,
int num_inputs,
const Attrs& raw_attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto attrs = raw_attrs.as<OneHotAttrs>();
CHECK(attrs != nullptr);

const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "one_hot: expect input data type to be TensorType but get "
<< types[0];
return false;
}

int depth = static_cast<int>(attrs->depth->value);
CHECK_GT(depth, 0)
<< "Invalid one_hot attributes (depth): " << attrs->depth;
const auto* on_value = types[1].as<TensorTypeNode>();
const auto* off_value = types[2].as<TensorTypeNode>();
if (on_value == nullptr || off_value == nullptr) {
return false;
}

CHECK_EQ(on_value->shape.size(), 0) << "on_value should be a scalar";
CHECK_EQ(off_value->shape.size(), 0) << "off_value should be a scalar";

int axis;
if (!attrs->axis.defined()) {
axis = static_cast<int>(data->shape.size());
} else {
axis = static_cast<int>(attrs->axis->value);
    CHECK_GE(axis, -1)
        << "one_hot: axis should be greater than or equal to -1.";
CHECK_LT(axis, static_cast<int>(data->shape.size()))
<< "axis should be within the input dimension range.";
if (axis < 0) {
axis = static_cast<int>(data->shape.size());
}
}

std::vector<IndexExpr> oshape;
const auto ndim_data = static_cast<int>(data->shape.size());

oshape.reserve(ndim_data + 1);
for (int i = 0; i < axis; ++i) {
oshape.emplace_back(data->shape[i]);
}
  oshape.emplace_back(depth);
  for (int i = axis; i < ndim_data; ++i) {
    oshape.emplace_back(data->shape[i]);
  }
reporter->Assign(types[3], TensorTypeNode::make(Array<IndexExpr>(oshape),
on_value->dtype));
return true;
}
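
In Python terms, the shape rule OneHotRel enforces is the following (an illustrative sketch, not part of the diff):

    def one_hot_out_shape(in_shape, depth, axis=-1):
        ndim = len(in_shape)
        if axis < 0:
            axis = ndim  # -1 appends a new innermost axis
        return list(in_shape[:axis]) + [depth] + list(in_shape[axis:])

    assert one_hot_out_shape((4,), 3) == [4, 3]
    assert one_hot_out_shape((3, 3), 4, axis=0) == [4, 3, 3]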

Array<Tensor> OneHotCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<OneHotAttrs>();
CHECK(param != nullptr);
Tensor on_value = inputs[1];
Tensor off_value = inputs[2];

return Array<Tensor>{ topi::one_hot(inputs[0], param->depth, on_value, off_value, param->axis) };
}

Expr MakeOneHot(Expr data,
Integer depth,
Expr on_value,
Expr off_value,
Integer axis) {
auto attrs = make_node<OneHotAttrs>();
attrs->depth = std::move(depth);
attrs->axis = std::move(axis);
static const Op& op = Op::Get("one_hot");
return CallNode::make(op, {data, on_value, off_value}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.one_hot")
.set_body_typed(MakeOneHot);

RELAY_REGISTER_OP("one_hot")
.describe(R"code(Returns one-hot array within a given interval-depth.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.OneHotAttrs")
.set_num_inputs(3)
.set_support_level(3)
.add_type_rel("OneHot", OneHotRel)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
.set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions);
} // namespace relay
} // namespace tvm
45 changes: 45 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -685,6 +685,49 @@ def verify_gather_nd(xshape, yshape, y_data):
verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])

def test_one_hot_infer_type():
def verify_one_hot(dshape, depth, oshape, axis=None):
input = relay.var("input", relay.TensorType(dshape, "int32"))
on_value = relay.var("on_value", relay.scalar_type("float32"))
off_value = relay.var("off_value", relay.scalar_type("float32"))
y = relay.one_hot(input, on_value=on_value, off_value=off_value, depth=depth, axis=axis)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(oshape, "float32")

    d1, d2 = tvm.var("d1"), tvm.var("d2")

    verify_one_hot((d1,), 2, (d1, 2), -1)
    verify_one_hot((4,), 3, (4, 3))
    verify_one_hot((3, 3), 4, (3, 3, 4))
    verify_one_hot((d1, d2), 5, (d1, d2, 5), -1)

def test_one_hot():
def verify_one_hot(src_shape, depth, on_value_data, off_value_data, axis=None):
data_dtype = "int32"
value_dtype = "float32"
        input_data = np.arange(np.prod(src_shape), dtype=data_dtype).reshape(src_shape)
input = relay.var("input", relay.TensorType(input_data.shape, data_dtype))
on_value = relay.var("on_value", relay.scalar_type(value_dtype))
off_value = relay.var("off_value", relay.scalar_type(value_dtype))
z = relay.one_hot(input, on_value=on_value, off_value=off_value, depth=depth, axis=axis)

func = relay.Function([input, on_value, off_value], z)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
on_value_npy = np.array(on_value_data, dtype=value_dtype)
off_value_npy = np.array(off_value_data, dtype=value_dtype)
op_res = intrp.evaluate(func)(input_data, on_value_npy, off_value_npy)
                ref_res = np.where(np.eye(depth, dtype=value_dtype)[input_data] == 1,
                                   on_value_npy, off_value_npy)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

verify_one_hot((4,), 4, 1.0, 0.0, -1)

if __name__ == "__main__":
test_arange()
test_cast()
@@ -715,3 +758,5 @@ def verify_gather_nd(xshape, yshape, y_data):
test_tile()
test_repeat()
test_gather_nd()
test_one_hot_infer_type()
test_one_hot()
56 changes: 56 additions & 0 deletions topi/include/topi/transform.h
@@ -660,6 +660,62 @@ inline Tensor take(const Tensor& a,
}


/*!
* \brief Produce a one-hot tensor from an index tensor, inserting a new
*  dimension of size depth along the given axis.
*
* \param a The source array.
* \param depth The depth of the one hot dimension to expand.
* \param on_value The value to fill in output when indices[j] = i.
* \param off_value The value to fill in output when indices[j] != i.
* \param axis The axis at which to insert the depth dimension;
*  -1 appends a new innermost axis.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor whose op member is the one_hot operation
*/
inline Tensor one_hot(const Tensor& a,
int depth,
const Tensor& on_value,
const Tensor& off_value,
int axis,
std::string name = "T_one_hot",
std::string tag = kInjective) {
int input_shape = static_cast<int>(a->shape.size());
  CHECK_GE(axis, -1) << "axis out of bounds, must be >= -1, got " << axis;
  CHECK_LT(axis, static_cast<int>(a->shape.size()))
      << "axis out of bounds, must be < a->shape.size() "
      << "(" << axis << " vs. " << a->shape.size() << ")";

Array<Expr> out_shape;
for (int i = 0; i < input_shape; ++i) {
if (axis == static_cast<int>(i)) {
out_shape.push_back(depth);
out_shape.push_back(a->shape[i]);
} else {
out_shape.push_back(a->shape[i]);
}
}
if (axis < 0) {
out_shape.push_back(depth);
axis = input_shape;
}

return compute(
out_shape, [&](const Array<Var>& indices) {
Array<Expr> real_indices;
for (int j = 0; j < axis; ++j) {
real_indices.push_back(indices[j]);
}
if (axis < input_shape) {
for (int j = axis + 1; j < static_cast<int>(indices.size()); ++j) {
real_indices.push_back(indices[j]);
}
}
Expr ret = tvm::ir::Select::make(a(real_indices) == indices[axis], on_value(), off_value());
return ret;
}, name, tag);
}
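
A NumPy reference for the compute rule above (an illustrative sketch, not part of the diff): every output coordinate keeps the input coordinates except the one along the inserted depth axis, which is compared against the stored index value.

    import numpy as np

    def one_hot_ref(a, depth, on_value, off_value, axis=-1):
        if axis < 0:
            axis = a.ndim  # append a new innermost axis
        out_shape = a.shape[:axis] + (depth,) + a.shape[axis:]
        out = np.empty(out_shape, dtype=np.result_type(on_value, off_value))
        for idx in np.ndindex(*out_shape):
            real_idx = idx[:axis] + idx[axis + 1:]  # drop the depth coordinate
            out[idx] = on_value if a[real_idx] == idx[axis] else off_value
        return out

    print(one_hot_ref(np.array([0, 2, 1]), 3, 5.0, 0.0))
    # [[5. 0. 0.], [0. 0. 5.], [0. 5. 0.]]
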
/*!
* \brief Mask the out-of-boundary elements of each sequence.
*
27 changes: 27 additions & 0 deletions topi/python/topi/transform.py
@@ -518,3 +518,30 @@ def where(condition, x, y):
A Tensor selected from x or y depending on condition.
"""
return cpp.where(condition, x, y)

def one_hot(data, depth, on_value, off_value, axis=-1):
    """Creates a one-hot tensor from an input index tensor along the given axis.

    Parameters
    ----------
    data : tvm.Tensor
        n-D input indices; can be any layout.

    depth : int
        A scalar defining the depth of the one-hot dimension.

    on_value : tvm.Tensor
        Scalar tensor defining the value to fill in output when indices[j] = i.

    off_value : tvm.Tensor
        Scalar tensor defining the value to fill in output when indices[j] != i.

    axis : int, optional
        The axis at which to insert the depth dimension. The default is -1,
        which appends a new innermost axis.

    Returns
    -------
    result : tvm.Tensor
        The resulting one-hot tensor.
    """
    return cpp.one_hot(data, depth, on_value, off_value, axis)
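
A usage sketch for the wrapper above, assuming the 2019-era tvm API (tvm.placeholder / tvm.create_schedule) and the topi.one_hot FFI registration added in this commit; names are illustrative:

    import tvm
    import topi

    indices = tvm.placeholder((4,), dtype="int32", name="indices")
    on_value = tvm.placeholder((), dtype="float32", name="on_value")
    off_value = tvm.placeholder((), dtype="float32", name="off_value")

    out = topi.one_hot(indices, 3, on_value, off_value, axis=-1)
    s = tvm.create_schedule(out.op)
    f = tvm.build(s, [indices, on_value, off_value, out], "llvm")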
8 changes: 7 additions & 1 deletion topi/src/topi.cc
@@ -157,7 +157,6 @@ TVM_REGISTER_GLOBAL("topi.sin")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sin(args[0]);
});

TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
@@ -358,6 +357,13 @@ TVM_REGISTER_GLOBAL("topi.take")
}
});

TVM_REGISTER_GLOBAL("topi.one_hot")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int depth = args[1];
int axis = args[4];
*rv = one_hot(args[0], depth, args[2], args[3], axis);
});

TVM_REGISTER_GLOBAL("topi.sequence_mask")
.set_body([](TVMArgs args, TVMRetValue *rv) {
double pad_val = args[2];