Skip to content

Commit

Permalink
[Relay][Dynamic] OneHot operation (apache#6209)
Browse files Browse the repository at this point in the history
* Dynamic OneHot Op

* refactor dynamic_to_static

* add onehot to dynamic_to_static pass
  • Loading branch information
Matthew Brookhart authored and trevor-m committed Sep 3, 2020
1 parent 03c3e97 commit 303d155
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 61 deletions.
19 changes: 11 additions & 8 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1421,22 +1421,25 @@ inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
* \param depth depth of the one-hot dimension.
* \param axis axis to fill.
* \param dtype data type of the output tensor.
* \param oshape shape of the output tensor.
* \param name output tensor name.
* \param tag output tensor tag.
* \return one-hot tensor.
*/
inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value,
int depth, int axis, const DataType& dtype,
Array<PrimExpr> oshape = Array<PrimExpr>(),
const std::string name = "T_one_hot", const std::string tag = kInjective) {
Array<PrimExpr> oshape;
int ndim = indices->shape.size() + 1;
int indices_index = 0;
int true_axis = (axis == -1) ? indices->shape.size() : axis;
for (int i = 0; i < ndim; i++) {
if (i == true_axis) {
oshape.push_back(Integer(depth));
} else {
oshape.push_back(indices->shape[indices_index++]);
if (oshape.size() == 0) {
int ndim = indices->shape.size() + 1;
int indices_index = 0;
for (int i = 0; i < ndim; i++) {
if (i == true_axis) {
oshape.push_back(Integer(depth));
} else {
oshape.push_back(indices->shape[indices_index++]);
}
}
}

Expand Down
35 changes: 29 additions & 6 deletions python/tvm/relay/op/dyn/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@
_reg.register_broadcast_schedule("dyn.broadcast_to")
_reg.register_injective_schedule("dyn.reshape")
_reg.register_broadcast_schedule("dyn.tile")
_reg.register_injective_schedule("dyn.one_hot")


@script
def _reshape_shape_func_input_data(data, newshape, ndim):
out = output_tensor((ndim,), "int64")
data_shape = allocate((len(data.shape),), "int64")
out = output_tensor((ndim, ), "int64")
data_shape = allocate((len(data.shape), ), "int64")
for x in const_range(len(data.shape)):
data_shape[x] = int64(data.shape[x])
src_idx = 0
Expand Down Expand Up @@ -59,7 +61,7 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
elif newshape[i] == -3:
assert data_shape.shape[0] - src_idx > 1, \
"Not enough dims in input shape for -3"
out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
out[dst_idx] = data_shape[src_idx] * data_shape[src_idx + 1]
src_idx += 2
dst_idx += 1
elif newshape[i] == -4:
Expand All @@ -82,14 +84,15 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
out[infer_idx] = old_size // new_size
return out


@_reg.register_shape_func("dyn.reshape", True)
def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
    """Shape function for dyn.reshape.

    Delegates to the hybrid-script helper, which reads the target shape
    from the runtime `newshape` tensor (data-dependent output shape).
    """
    out_shape = _reshape_shape_func_input_data(*inputs, out_ndims[0])
    return [out_shape]


@script
def _tile_shape_func(data, reps, ndim, tndim, rndim):
out = output_tensor((tndim,), "int64")
out = output_tensor((tndim, ), "int64")

if ndim == rndim:
for i in const_range(tndim):
Expand Down Expand Up @@ -120,5 +123,25 @@ def tile_shape_func(attrs, inputs, _):
ndim = len(inputs[0].shape)
rndim = inputs[1].shape[0].value
tndim = ndim if ndim > rndim else rndim
return [_tile_shape_func(inputs[0], reps, convert(ndim),
convert(tndim), convert(rndim))]
return [_tile_shape_func(inputs[0], reps, convert(ndim), convert(tndim), convert(rndim))]


@script
def _onehot_shape_func(dshape, k, axis):
    # Output shape of one_hot: the indices shape `dshape` with the runtime
    # depth value (k[0], a 1-element tensor) inserted at position `axis`.
    ndim = len(dshape) + 1
    out = output_tensor((ndim, ), "int64")
    # Dimensions before the one-hot axis are copied through unchanged.
    for i in const_range(axis):
        out[i] = int64(dshape[i])
    # The one-hot axis takes the runtime depth value.
    out[axis] = int64(k[0])
    # Dimensions after the axis are shifted right by one.
    for j in const_range(axis + 1, ndim):
        out[j] = int64(dshape[j - 1])
    return out


@_reg.register_shape_func("dyn.one_hot", True)
def one_hot_shape_func(attrs, inputs, _):
    """
    Shape function for dyn.one_hot op.
    """
    # axis == -1 means the one-hot dimension is appended as the last axis.
    indices_rank = len(inputs[0].shape)
    if attrs.axis == -1:
        true_axis = indices_rank
    else:
        true_axis = attrs.axis
    return [_onehot_shape_func(inputs[0].shape, inputs[3], convert(true_axis))]
15 changes: 14 additions & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def squeeze(data, axis=None):
"""
return _make.squeeze(data, axis)


def reshape(data, newshape):
"""Reshape the input array.
Expand Down Expand Up @@ -228,6 +229,7 @@ def reshape(data, newshape):
newshape = tempshape
return _make.reshape(data, list(newshape))


def argwhere(condition):
"""Find the indices of elements of a tensor that are
non-zero.
Expand All @@ -251,6 +253,7 @@ def argwhere(condition):
"""
return _make.argwhere(condition)


def scatter(data, indices, updates, axis):
"""Update data at positions defined by indices with values in updates
Expand All @@ -275,6 +278,7 @@ def scatter(data, indices, updates, axis):
"""
return _make.scatter(data, indices, updates, axis)


def scatter_add(data, indices, updates, axis):
"""Update data by adding values in updates at positions defined by indices
Expand All @@ -299,6 +303,7 @@ def scatter_add(data, indices, updates, axis):
"""
return _make.scatter_add(data, indices, updates, axis)


def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
Expand Down Expand Up @@ -442,6 +447,7 @@ def arange(start, stop=None, step=None, dtype="float32"):

return _make.arange(start, stop, step, dtype)


def meshgrid(data, indexing="ij"):
"""Create coordinate matrices from coordinate vectors.
Expand Down Expand Up @@ -482,6 +488,7 @@ def meshgrid(data, indexing="ij"):
ret_size = len(data)
return TupleWrapper(_make.meshgrid(Tuple(data), indexing), ret_size)


def repeat(data, repeats, axis):
"""Repeats elements of an array.
By default, repeat flattens the input array into 1-D and then repeats the elements.
Expand Down Expand Up @@ -668,6 +675,7 @@ def where(condition, x, y):
"""
return _make.where(condition, x, y)


def broadcast_to(data, shape):
"""Return a scalar value array with the same type, broadcast to
the provided shape.
Expand All @@ -693,6 +701,7 @@ def broadcast_to(data, shape):
shape = list(shape)
return _make.broadcast_to(data, shape)


def broadcast_to_like(data, broadcast_type):
"""Return a scalar value array with the same shape and type as the input array.
Expand Down Expand Up @@ -1053,6 +1062,7 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):
"""
return _make.sequence_mask(data, valid_length, mask_value, axis)


def one_hot(indices, on_value, off_value, depth, axis, dtype):
"""
Returns a one-hot tensor where the locations repsented by indices take value on_value,
Expand All @@ -1070,7 +1080,7 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
off_value : relay.Expr
Value to fill at all other positions besides indices.
depth : int
depth : int or relay.Expr
Depth of the one-hot dimension.
axis : int
Expand All @@ -1095,6 +1105,8 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
[0, 1, 0],
[0, 0, 1]]
"""
if isinstance(depth, Expr):
return _dyn_make.one_hot(indices, on_value, off_value, depth, axis, dtype)
return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)


Expand All @@ -1120,6 +1132,7 @@ def unravel_index(indices, shape):

return _make.unravel_index(indices, shape)


def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0):
"""Converts a sparse representation into a dense tensor.
Expand Down
70 changes: 70 additions & 0 deletions src/relay/op/dyn/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,76 @@ RELAY_REGISTER_OP("dyn.ones")
.set_support_level(3)
.add_type_rel("DynamicInitOp", InitOpRel);

bool OneHotRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  // `types` contains: [indices, on_value, off_value, depth, result]
  // (the previous comment omitted `depth`; the size check below requires 5).
  CHECK_EQ(types.size(), 5);
  const auto* indices = types[0].as<TensorTypeNode>();
  CHECK(indices);

  const auto param = attrs.as<OneHotAttrs>();

  Array<IndexExpr> oshape;
  int ndim = indices->shape.size() + 1;
  int indices_index = 0;
  // axis == -1 places the one-hot dimension as the last axis.
  int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis;
  for (int i = 0; i < ndim; i++) {
    if (i == true_axis) {
      // The depth is a runtime value in the dynamic op, so its extent is
      // unknown at type-inference time.
      oshape.push_back(Any());
    } else {
      oshape.push_back(indices->shape[indices_index++]);
    }
  }

  reporter->Assign(types[4], TensorType(oshape, param->dtype));
  return true;
}

// Lower dyn.one_hot to TOPI. The extent of the one-hot axis is not known
// statically, so the shape from the inferred output type is forwarded to
// topi::one_hot instead of a static depth (hence depth = -1 below).
Array<te::Tensor> OneHotCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                const Type& out_type) {
  const auto* param = attrs.as<OneHotAttrs>();
  CHECK(param != nullptr);
  const auto* out_ttype = out_type.as<TensorTypeNode>();
  // Guard before dereference: previously out_ttype was used unchecked.
  CHECK(out_ttype != nullptr);
  // inputs: [indices, on_value, off_value, depth]; on/off are 0-d tensors,
  // hence the operator() calls to read their scalar expressions.
  return Array<te::Tensor>{topi::one_hot(inputs[0], inputs[1](), inputs[2](), -1, param->axis,
                                         param->dtype, out_ttype->shape)};
}

// Construct a dyn.one_hot call. Unlike the static one_hot maker, `depth`
// here is a relay Expr (a runtime value passed as the 4th call argument)
// rather than an integer attribute.
Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, Expr depth, int axis, DataType dtype) {
  auto attrs = make_object<OneHotAttrs>();
  attrs->axis = axis;
  attrs->dtype = dtype;
  static const Op& op = Op::Get("dyn.one_hot");
  return Call(op, {indices, on_value, off_value, depth}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn._make.one_hot").set_body_typed(MakeOneHot);

RELAY_REGISTER_OP("dyn.one_hot")
    // Doc fixes: "repsented" typo, and the fill values are on_value/off_value
    // (not 1/0), matching the argument descriptions below.
    .describe(R"code(Returns a one-hot tensor where the locations represented by indices take value on_value,
other locations take value off_value. Final dimension is <indices dimensions> x depth.
**indices** Locations to set to on_value.
**on_value** Value to fill at indices.
**off_value** Value to fill at all other positions besides indices.
**depth** Depth of the one-hot dimension.
**axis** Axis to fill.
**dtype**)code" TVM_ADD_FILELINE)
    .set_attrs_type<OneHotAttrs>()
    .set_num_inputs(4)
    .add_argument("indices", "Tensor", "Locations to set to on_value.")
    .add_argument("on_value", "Expr", "Value to fill at indices.")
    .add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.")
    // Fixed: this description was copy-pasted from off_value.
    .add_argument("depth", "Expr", "Depth of the one-hot dimension.")
    .set_support_level(10)
    .add_type_rel("DynOneHot", OneHotRel)
    .set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);

} // namespace dyn
} // namespace relay
} // namespace tvm
2 changes: 2 additions & 0 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool

Expr MakeZeros(Array<Integer> shape, DataType dtype);

Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_MAKE_OP_H_
Loading

0 comments on commit 303d155

Please sign in to comment.