From 594171fe2bcb692c887fa8c7df36ef18ffcd34ec Mon Sep 17 00:00:00 2001 From: Jon Date: Wed, 14 Aug 2019 16:33:52 -0700 Subject: [PATCH 01/19] Add one-hot to Relay --- include/tvm/relay/attrs/transform.h | 10 +++++ python/tvm/relay/frontend/tensorflow.py | 14 ++++++ python/tvm/relay/op/_transform.py | 1 + python/tvm/relay/op/transform.py | 20 +++++++++ src/relay/op/tensor/transform.cc | 57 +++++++++++++++++++++++++ 5 files changed, 102 insertions(+) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index e43fd5f7a2e7..79577d1a841b 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -298,6 +298,16 @@ struct NdarraySizeAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used in one-hot operator */ +struct OneHotAttrs : public tvm::AttrsNode { + int depth; + + TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") { + TVM_ATTR_FIELD(depth).set_default(1) + .describe("Depth of the one hot dimension."); + } +}; // struct OneHotAttrs + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index bbc0fec67bf6..0e43a6526c31 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1212,6 +1212,19 @@ def _impl(inputs, attr, params): return get_relay_op('log')(add_out) return _impl +def one_hot(): + def _impl(inputs, attr, params): + one_hot = AttrCvt('one_hot', + ignores=['on_value', 'off_value', 'axis', 'dtype']) + + out_dtype = attr.get("T", None) + if out_dtype is None: + out_dtype_name = "float32" + else: + out_dtype_name = out_dtype.name + return _op.cast(one_hot, out_dtype_name) + return _impl + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1284,6 +1297,7 @@ def _impl(inputs, attr, params): 'Mul' : _elemwise('multiply'), 'Neg' : AttrCvt('negative'), 'NotEqual' : _broadcast('not_equal'), + 'OneHot' : _one_hot(), 'Pack' : _pack(), 'Pad' : _pad('Pad'), 'PadV2' : _pad('PadV2'), diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 51e761516eed..a4c937524c70 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -52,6 +52,7 @@ _reg.register_schedule("_contrib_reverse_reshape", schedule_injective) _reg.register_schedule("gather_nd", schedule_injective) _reg.register_schedule("sequence_mask", schedule_injective) +_reg.register_schedule("one_hot", schedule_injective) # layout_transform diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 5d8d28006ecb..d278b595fa0d 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -748,3 +748,23 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0): [[ 0.1, 0.1, 0.1], [ 16., 17., 18.]]] """ return _make.sequence_mask(data, valid_length, mask_value, axis) + +def one_hot(indices, depth): + """ + Returns a one-hot tensor where the locations repsented by indices take value 1, + other locations take value 0. + + Parameters + ---------- + indices : relay.Expr + Locations to set to 1. + + depth : int + Depth of the one-hot dimension. + + Returns + ------- + ret : relay.Expr + The one-hot tensor. + """ + return _make.one_hot(indices, depth) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 03a92b35d396..e882879b41df 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2482,5 +2482,62 @@ Examples:: .set_attr("FTVMCompute", SequenceMaskCompute) .set_attr("TOpPattern", kInjective); +// relay.one_hot +TVM_REGISTER_NODE_TYPE(OneHotAttrs); + +bool OneHotRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [indices, result] + CHECK_EQ(types.size(), 2); + const auto* indices = types[0].as(); + CHECK(indices); + + const auto param = attrs.as(); + CHECK(param->depth != 0); + + Array output_shape(indices->shape); + output_shape.push_back(param->depth); + + reporter->Assign(types[1], TensorTypeNode::make(output_shape, indices->dtype)); + return true; +} + +Array OneHotCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const auto* param = attrs.as(); + CHECK(param != nullptr); + return Array{ topi::one_hot(inputs[0], param->depth) }; +} + +Expr MakeOneHot(Expr indices, + int depth) { + auto attrs = make_node(); + attrs->depth = std::move(depth); + static const Op& op = Op::Get("one_hot"); + return CallNode::make(op, {indices}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make.one_hot") +.set_body_typed(MakeOneHot); + +RELAY_REGISTER_OP("one_hot") +.describe(R"code(Returns a one-hot tensor where the locations repsented by indices take value 1, + other locations take value 0. + + **indices** Locations to set to 1. + + **depth** Depth of the one-hot dimension.)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.OneHotAttrs") +.set_num_inputs(1) +.add_argument("indices", "Tensor", "Locations to set to 1.") +.set_support_level(10) +.add_type_rel("OneHot", OneHotRel) +.set_attr("FTVMCompute", OneHotCompute) +.set_attr("TOpPattern", kOutEWiseFusable); + } // namespace relay } // namespace tvm From 17233a0bb82ffc6846230799b03ee3c6d4d01f44 Mon Sep 17 00:00:00 2001 From: Jon Date: Wed, 14 Aug 2019 17:00:04 -0700 Subject: [PATCH 02/19] topi implementation --- python/tvm/relay/frontend/tensorflow.py | 2 +- topi/include/topi/transform.h | 24 ++++++++++++++++++++++++ topi/python/topi/transform.py | 20 ++++++++++++++++++++ topi/src/topi.cc | 5 +++++ 4 files changed, 50 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 0e43a6526c31..a8e76ce29d5c 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1212,7 +1212,7 @@ def _impl(inputs, attr, params): return get_relay_op('log')(add_out) return _impl -def one_hot(): +def _one_hot(): def _impl(inputs, attr, params): one_hot = AttrCvt('one_hot', ignores=['on_value', 'off_value', 'axis', 'dtype']) diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index e8a65b05a42c..10212b8ae29c 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1247,5 +1247,29 @@ inline Tensor ndarray_size(const Tensor& src, }, name, tag); } +/*! + * \brief Returns a one-hot tensor where the locations repsented by indices take value 1, + other locations take value 0. + * \param indices locations to set to 1. + * \param depth depth of the one-hot dimension. + * \return one-hot tensor. + */ +inline Tensor one_hot(const Tensor& indices, + int depth, + const std::string name = "T_one_hot", + const std::string tag = kInjective) { + Array out_shape = indices->shape; + out_shape.push_back(depth); + return compute(out_shape, [&](const Array& iter_vars) { + Array outer_indices; + for (auto i = 0; i < iter_vars.size(); i++) { + outer_indices.push_back(iter_vars[i]); + } + + auto idx = iter_vars[iter_vars.size() - 1]; + return tvm::if_then_else(indices(outer_indices) == idx, 1, 0); + }, name, tag); +} + } // namespace topi #endif // TOPI_TRANSFORM_H_ diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 5e87933c2806..3fec90e29cb6 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -518,3 +518,23 @@ def where(condition, x, y): A Tensor selected from x or y depending on condition. """ return cpp.where(condition, x, y) + +def one_hot(indices, depth): + """ + Returns a one-hot tensor where the locations repsented by indices take value 1, + other locations take value 0. + + Parameters + ---------- + indices : relay.Expr + Locations to set to 1. + + depth : int + Depth of the one-hot dimension. + + Returns + ------- + ret : relay.Expr + The one-hot tensor. + """ + return cpp.one_hot(indices, depth) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 799b660df3b8..a3cd63f0cdbb 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -417,6 +417,11 @@ TVM_REGISTER_GLOBAL("topi.strided_slice") *rv = strided_slice(args[0], args[1], args[2], args[3]); }); +TVM_REGISTER_GLOBAL("topi.one_hot") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = one_hot(args[0], args[1]); + }); + /* Ops from nn/upsampling.h */ TVM_REGISTER_GLOBAL("topi.nn.upsampling") .set_body([](TVMArgs args, TVMRetValue *rv) { From f76e4466eaf42c82a7c162e471d168cc686789cb Mon Sep 17 00:00:00 2001 From: Jon Date: Thu, 15 Aug 2019 10:00:04 -0700 Subject: [PATCH 03/19] Working --- python/tvm/relay/frontend/tensorflow.py | 7 ++++++- python/tvm/relay/op/transform.py | 13 ++++++++++++- src/relay/op/tensor/transform.cc | 2 +- topi/include/topi/transform.h | 2 +- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index a8e76ce29d5c..15cc83076d68 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1214,8 +1214,13 @@ def _impl(inputs, attr, params): def _one_hot(): def _impl(inputs, attr, params): + indices = inputs[0] + depth = int(_infer_value(inputs[1], params).asnumpy()[0]) + + new_inputs = [inputs[0]] one_hot = AttrCvt('one_hot', - ignores=['on_value', 'off_value', 'axis', 'dtype']) + ignores=['axis', 'TI'], + extras={ 'depth' : depth })([inputs[0]], attr) out_dtype = attr.get("T", None) if out_dtype is None: diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index d278b595fa0d..71ef6f7996df 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -752,7 +752,7 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0): def one_hot(indices, depth): """ Returns a one-hot tensor where the locations repsented by indices take value 1, - other locations take value 0. + other locations take value 0. Final dimension is x depth. Parameters ---------- @@ -766,5 +766,16 @@ def one_hot(indices, depth): ------- ret : relay.Expr The one-hot tensor. + + Examples + -------- + .. code-block:: python + + indices = [1., 2., 3.] + + relay.one_hot(indices, 2) = + [[[1., 0., 0.]], + [[0., 1., 0.]], + [[0., 0., 1.]]] """ return _make.one_hot(indices, depth) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index e882879b41df..737ae2fafbe0 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2526,7 +2526,7 @@ TVM_REGISTER_API("relay.op._make.one_hot") RELAY_REGISTER_OP("one_hot") .describe(R"code(Returns a one-hot tensor where the locations repsented by indices take value 1, - other locations take value 0. + other locations take value 0. Final dimension is x depth. **indices** Locations to set to 1. diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 10212b8ae29c..a70f9b63bef0 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1262,7 +1262,7 @@ inline Tensor one_hot(const Tensor& indices, out_shape.push_back(depth); return compute(out_shape, [&](const Array& iter_vars) { Array outer_indices; - for (auto i = 0; i < iter_vars.size(); i++) { + for (size_t i = 0; i < iter_vars.size() - 1; i++) { outer_indices.push_back(iter_vars[i]); } From 4b2308b5b90e5b504519b044226987a98984a84b Mon Sep 17 00:00:00 2001 From: Jon Date: Thu, 15 Aug 2019 10:27:23 -0700 Subject: [PATCH 04/19] add topi test --- python/tvm/relay/frontend/tensorflow.py | 1 - topi/tests/python/test_topi_transform.py | 29 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 15cc83076d68..85ecd0e4ae7f 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1217,7 +1217,6 @@ def _impl(inputs, attr, params): indices = inputs[0] depth = int(_infer_value(inputs[1], params).asnumpy()[0]) - new_inputs = [inputs[0]] one_hot = AttrCvt('one_hot', ignores=['axis', 'TI'], extras={ 'depth' : depth })([inputs[0]], attr) diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 64305b4a52cc..35e436bba73f 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -473,6 +473,29 @@ def check_device(device): for device in get_all_backend(): check_device(device) +def verify_one_hot(indices_shape, depth): + indices = tvm.placeholder(shape=indices_shape, name="indices", dtype="int32") + one_hot_result = topi.transform.one_hot(indices, depth) + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(one_hot_result) + fn = tvm.build(s, [indices, one_hot_result], device, name="one_hot") + indices_npy = np.random.randint(0, depth, size=indices_shape).astype(indices.dtype) + out_npy = np.eye(depth)[indices_npy] + indices_nd = tvm.nd.array(indices_npy, ctx) + out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(one_hot_result.dtype), ctx) + fn(indices_nd, out_nd) + out_topi = out_nd.asnumpy() + tvm.testing.assert_allclose(out_topi, out_npy) + + for device in get_all_backend(): + check_device(device) + def test_strided_slice(): verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1]) @@ -770,6 +793,11 @@ def check_device(device): for backend in get_all_backend(): check_device(backend) +def test_one_hot(): + verify_one_hot((3,), 3) + verify_one_hot((4,), 3) + verify_one_hot((2, 2), 5) + verify_one_hot((3, 2, 4, 5), 6) if __name__ == "__main__": test_strided_slice() @@ -793,3 +821,4 @@ def check_device(device): test_sequence_mask() test_ndarray_size() test_where_fusion() + test_one_hot() From a861369dd5fc796a435b2590b0ce99c8909c45e1 Mon Sep 17 00:00:00 2001 From: Jon Date: Thu, 15 Aug 2019 10:36:53 -0700 Subject: [PATCH 05/19] Add TF test --- .../frontend/tensorflow/test_forward.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 7b0bcfb7d584..37088684d606 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2158,6 +2158,24 @@ def test_placeholder(): compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0', init_global_variables=True) +####################################################################### +# OneHot +# ---------------------- +def _test_forward_one_hot(indices_shape, depth, out_dtype): + inp_array1 = np.random.randint(0, 5, size=indices_shape) + with tf.Graph().as_default(): + in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype) + out = tf.one_hot(in1, depth, dtype=out_dtype) + compare_tf_with_tvm(inp_array1, in1.name, out.name) + +def test_forward_one_hot(): + _test_forward_one_hot((3,), 3, "int32") + _test_forward_one_hot((3,), 3, "float32") + _test_forward_one_hot((2, 2), 5, "int32") + _test_forward_one_hot((2, 2), 5, "float32") + _test_forward_one_hot((3, 2, 4, 5), 6, "int32") + _test_forward_one_hot((3, 2, 4, 5), 6, "float32") + ####################################################################### # Main @@ -2193,6 +2211,7 @@ def test_placeholder(): test_forward_right_shift() test_forward_left_shift() test_forward_truncatemod() + test_forward_one_hot() # Activations test_forward_sigmoid() From 173677e409a9e1bec964f477bdfb3b6977e042d0 Mon Sep 17 00:00:00 2001 From: Jon Date: Thu, 15 Aug 2019 10:39:41 -0700 Subject: [PATCH 06/19] Fix check --- src/relay/op/tensor/transform.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 737ae2fafbe0..519e606f22ec 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2495,7 +2495,7 @@ bool OneHotRel(const Array& types, CHECK(indices); const auto param = attrs.as(); - CHECK(param->depth != 0); + CHECK(param->depth > 0); Array output_shape(indices->shape); output_shape.push_back(param->depth); From 807414b740dd0b48be4549c4591948e7c1602303 Mon Sep 17 00:00:00 2001 From: Jon Date: Thu, 15 Aug 2019 10:56:58 -0700 Subject: [PATCH 07/19] fix linting issues --- include/tvm/relay/attrs/transform.h | 4 ++-- python/tvm/relay/frontend/tensorflow.py | 6 ++---- python/tvm/relay/op/transform.py | 2 +- src/relay/op/tensor/transform.cc | 2 +- topi/include/topi/transform.h | 4 +++- topi/python/topi/transform.py | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 79577d1a841b..f5662e446409 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -303,8 +303,8 @@ struct OneHotAttrs : public tvm::AttrsNode { int depth; TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") { - TVM_ATTR_FIELD(depth).set_default(1) - .describe("Depth of the one hot dimension."); + TVM_ATTR_FIELD(depth).set_default(1) + .describe("Depth of the one hot dimension."); } }; // struct OneHotAttrs diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 85ecd0e4ae7f..7c08e12289a5 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1214,12 +1214,10 @@ def _impl(inputs, attr, params): def _one_hot(): def _impl(inputs, attr, params): - indices = inputs[0] depth = int(_infer_value(inputs[1], params).asnumpy()[0]) - - one_hot = AttrCvt('one_hot', + one_hot = AttrCvt('one_hot', ignores=['axis', 'TI'], - extras={ 'depth' : depth })([inputs[0]], attr) + extras={'depth' : depth})([inputs[0]], attr) out_dtype = attr.get("T", None) if out_dtype is None: diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 71ef6f7996df..fb8f9e72c647 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -751,7 +751,7 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0): def one_hot(indices, depth): """ - Returns a one-hot tensor where the locations repsented by indices take value 1, + Returns a one-hot tensor where the locations repsented by indices take value 1, other locations take value 0. Final dimension is x depth. Parameters diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 519e606f22ec..3e963b0aa15a 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2495,7 +2495,7 @@ bool OneHotRel(const Array& types, CHECK(indices); const auto param = attrs.as(); - CHECK(param->depth > 0); + CHECK_GT(param->depth, 0); Array output_shape(indices->shape); output_shape.push_back(param->depth); diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index a70f9b63bef0..f6bd53dc6e41 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1252,6 +1252,8 @@ inline Tensor ndarray_size(const Tensor& src, other locations take value 0. * \param indices locations to set to 1. * \param depth depth of the one-hot dimension. + * \param name output tensor name. + * \param tag output tensor tag. * \return one-hot tensor. */ inline Tensor one_hot(const Tensor& indices, @@ -1265,7 +1267,7 @@ inline Tensor one_hot(const Tensor& indices, for (size_t i = 0; i < iter_vars.size() - 1; i++) { outer_indices.push_back(iter_vars[i]); } - + auto idx = iter_vars[iter_vars.size() - 1]; return tvm::if_then_else(indices(outer_indices) == idx, 1, 0); }, name, tag); diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 3fec90e29cb6..e3146717a7d2 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -521,7 +521,7 @@ def where(condition, x, y): def one_hot(indices, depth): """ - Returns a one-hot tensor where the locations repsented by indices take value 1, + Returns a one-hot tensor where the locations repsented by indices take value 1, other locations take value 0. Parameters From 79c598a4dcd20c31596e887143f0b5b1afdedb05 Mon Sep 17 00:00:00 2001 From: Jon Date: Thu, 15 Aug 2019 10:58:22 -0700 Subject: [PATCH 08/19] fix documentation --- python/tvm/relay/op/transform.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index fb8f9e72c647..792cde999e2e 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -774,8 +774,8 @@ def one_hot(indices, depth): indices = [1., 2., 3.] relay.one_hot(indices, 2) = - [[[1., 0., 0.]], - [[0., 1., 0.]], - [[0., 0., 1.]]] + [[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]] """ return _make.one_hot(indices, depth) From 7531c74b089c54b28b47a32d75854c7d0202705f Mon Sep 17 00:00:00 2001 From: Jon Date: Mon, 19 Aug 2019 10:36:30 -0700 Subject: [PATCH 09/19] Fix documentation --- python/tvm/relay/op/transform.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 792cde999e2e..9d936fba40b9 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -771,11 +771,11 @@ def one_hot(indices, depth): -------- .. code-block:: python - indices = [1., 2., 3.] + indices = [0, 1, 2] - relay.one_hot(indices, 2) = - [[1., 0., 0.], - [0., 1., 0.], - [0., 0., 1.]] + relay.one_hot(indices, 3) = + [[1, 0, 0], + [0, 1, 0], + [0, 0, 1]] """ return _make.one_hot(indices, depth) From 8e6ac5bf320d9b0b39f33f2662b7522c1f161808 Mon Sep 17 00:00:00 2001 From: Jon Date: Mon, 19 Aug 2019 14:03:55 -0700 Subject: [PATCH 10/19] Add support for on_value, off_value, axis, dtype --- include/tvm/relay/attrs/transform.h | 16 ++++++++++++++-- python/tvm/relay/frontend/tensorflow.py | 16 ++++++---------- python/tvm/relay/op/transform.py | 17 ++++++++++++----- src/relay/op/tensor/transform.cc | 16 ++++++++++++---- topi/include/topi/transform.h | 7 ++++++- topi/src/topi.cc | 6 +++++- 6 files changed, 55 insertions(+), 23 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index f5662e446409..32b1cc7de9f6 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -301,10 +301,22 @@ struct NdarraySizeAttrs : public tvm::AttrsNode { /*! \brief Attributes used in one-hot operator */ struct OneHotAttrs : public tvm::AttrsNode { int depth; + double on_value; + double off_value; + int axis; + DataType dtype; TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") { - TVM_ATTR_FIELD(depth).set_default(1) - .describe("Depth of the one hot dimension."); + TVM_ATTR_FIELD(depth).set_default(1) + .describe("Depth of the one hot dimension."); + TVM_ATTR_FIELD(on_value).set_default(1) + .describe("On value."); + TVM_ATTR_FIELD(off_value).set_default(0) + .describe("Off value."); + TVM_ATTR_FIELD(axis).set_default(-1) + .describe("Axis to fill."); + TVM_ATTR_FIELD(dtype).set_default(NullValue()) + .describe("Output data type."); } }; // struct OneHotAttrs diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 7c08e12289a5..cfdd8fcb050b 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1215,16 +1215,12 @@ def _impl(inputs, attr, params): def _one_hot(): def _impl(inputs, attr, params): depth = int(_infer_value(inputs[1], params).asnumpy()[0]) - one_hot = AttrCvt('one_hot', - ignores=['axis', 'TI'], - extras={'depth' : depth})([inputs[0]], attr) - - out_dtype = attr.get("T", None) - if out_dtype is None: - out_dtype_name = "float32" - else: - out_dtype_name = out_dtype.name - return _op.cast(one_hot, out_dtype_name) + on_value = float(_infer_value(inputs[2], params).asnumpy()[0]) + off_value = float(_infer_value(inputs[3], params).asnumpy()[0]) + return AttrCvt('one_hot', + ignores=['TI'], + extras={'depth' : depth, 'on_value' : on_value, 'off_value' : off_value,\ + 'dtype' : attr['T'].name})([inputs[0]], attr) return _impl # compatible operators that do NOT require any conversion. diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 9d936fba40b9..506b2f487b45 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -749,19 +749,26 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0): """ return _make.sequence_mask(data, valid_length, mask_value, axis) -def one_hot(indices, depth): +def one_hot(indices, depth, on_value, off_value, axis, dtype): """ - Returns a one-hot tensor where the locations repsented by indices take value 1, - other locations take value 0. Final dimension is x depth. + Returns a one-hot tensor where the locations repsented by indices take value on_value, + other locations take value off_value. + Final dimension is x depth x . Parameters ---------- indices : relay.Expr - Locations to set to 1. + Locations to set to on_value. depth : int Depth of the one-hot dimension. + on_value : float + Value to fill at indices. + + off_value : float + Value to fill at all other positions besides indices. + Returns ------- ret : relay.Expr @@ -778,4 +785,4 @@ def one_hot(indices, depth): [0, 1, 0], [0, 0, 1]] """ - return _make.one_hot(indices, depth) + return _make.one_hot(indices, depth, on_value, off_value, axis, dtype) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 3e963b0aa15a..4759fe224baa 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2500,7 +2500,7 @@ bool OneHotRel(const Array& types, Array output_shape(indices->shape); output_shape.push_back(param->depth); - reporter->Assign(types[1], TensorTypeNode::make(output_shape, indices->dtype)); + reporter->Assign(types[1], TensorTypeNode::make(output_shape, param->dtype)); return true; } @@ -2510,13 +2510,21 @@ Array OneHotCompute(const Attrs& attrs, const Target& target) { const auto* param = attrs.as(); CHECK(param != nullptr); - return Array{ topi::one_hot(inputs[0], param->depth) }; + return Array{ topi::one_hot(inputs[0], param->depth, (float)param->on_value, (float)param->off_value, param->axis, param->dtype) }; } Expr MakeOneHot(Expr indices, - int depth) { + int depth, + double on_value, + double off_value, + int axis, + DataType dtype) { auto attrs = make_node(); attrs->depth = std::move(depth); + attrs->on_value = on_value; + attrs->off_value = off_value; + attrs->axis = axis; + attrs->dtype = dtype; static const Op& op = Op::Get("one_hot"); return CallNode::make(op, {indices}, Attrs(attrs), {}); } @@ -2533,7 +2541,7 @@ RELAY_REGISTER_OP("one_hot") **depth** Depth of the one-hot dimension.)code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.OneHotAttrs") .set_num_inputs(1) -.add_argument("indices", "Tensor", "Locations to set to 1.") +.add_argument("indices", "Tensor", "Locations to set to on_value.") .set_support_level(10) .add_type_rel("OneHot", OneHotRel) .set_attr("FTVMCompute", OneHotCompute) diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index f6bd53dc6e41..1d1c3adf737b 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1258,6 +1258,10 @@ inline Tensor ndarray_size(const Tensor& src, */ inline Tensor one_hot(const Tensor& indices, int depth, + float on_value, + float off_value, + int axis, + const Type& dtype, const std::string name = "T_one_hot", const std::string tag = kInjective) { Array out_shape = indices->shape; @@ -1269,7 +1273,8 @@ inline Tensor one_hot(const Tensor& indices, } auto idx = iter_vars[iter_vars.size() - 1]; - return tvm::if_then_else(indices(outer_indices) == idx, 1, 0); + auto ret = ir::Select::make(indices(outer_indices) == idx, on_value, off_value); + return tvm::cast(dtype, ret); }, name, tag); } diff --git a/topi/src/topi.cc b/topi/src/topi.cc index a3cd63f0cdbb..f3663bbeb1b0 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -419,7 +419,11 @@ TVM_REGISTER_GLOBAL("topi.strided_slice") TVM_REGISTER_GLOBAL("topi.one_hot") .set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = one_hot(args[0], args[1]); + double on_value = args[2]; + double off_value = args[3]; + int axis = args[4]; + DataType dtype = args[5]; + *rv = one_hot(args[0], args[1], (float)on_value, (float)off_value, axis, dtype); }); /* Ops from nn/upsampling.h */ From b91fcc7d103f8d49e564ab9bce436c513d5cc0b1 Mon Sep 17 00:00:00 2001 From: Jon Date: Mon, 19 Aug 2019 15:58:30 -0700 Subject: [PATCH 11/19] Add full support for axis --- src/relay/op/tensor/transform.cc | 15 +++++++++--- topi/include/topi/transform.h | 39 +++++++++++++++++++++++--------- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 4759fe224baa..3d8a06e5629a 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2497,10 +2497,19 @@ bool OneHotRel(const Array& types, const auto param = attrs.as(); CHECK_GT(param->depth, 0); - Array output_shape(indices->shape); - output_shape.push_back(param->depth); + Array oshape; + int ndim = indices->shape.size() + 1; + int indices_index = 0; + int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis; + for (int i = 0; i < ndim; i++) { + if (i == true_axis) { + oshape.push_back(Integer(param->depth)); + } else { + oshape.push_back(indices->shape[indices_index++]); + } + } - reporter->Assign(types[1], TensorTypeNode::make(output_shape, param->dtype)); + reporter->Assign(types[1], TensorTypeNode::make(oshape, param->dtype)); return true; } diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 1d1c3adf737b..ead230cffa1a 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1248,10 +1248,13 @@ inline Tensor ndarray_size(const Tensor& src, } /*! - * \brief Returns a one-hot tensor where the locations repsented by indices take value 1, - other locations take value 0. - * \param indices locations to set to 1. + * \brief Returns a one-hot tensor where the locations repsented by indices take value on_value, + other locations take value off_value. + * \param indices locations to set to on_value. * \param depth depth of the one-hot dimension. + * \param on_value value that locations represented by indices take on. + * \param off_value value that other locations take on. + * \param axis axis of one-hot dimension. * \param name output tensor name. * \param tag output tensor tag. * \return one-hot tensor. @@ -1264,16 +1267,30 @@ inline Tensor one_hot(const Tensor& indices, const Type& dtype, const std::string name = "T_one_hot", const std::string tag = kInjective) { - Array out_shape = indices->shape; - out_shape.push_back(depth); - return compute(out_shape, [&](const Array& iter_vars) { - Array outer_indices; - for (size_t i = 0; i < iter_vars.size() - 1; i++) { - outer_indices.push_back(iter_vars[i]); + Array oshape; + int ndim = indices->shape.size() + 1; + int indices_index = 0; + int true_axis = (axis == -1) ? indices->shape.size() : axis; + for (int i = 0; i < ndim; i++) { + if (i == true_axis) { + oshape.push_back(Integer(depth)); + } else { + oshape.push_back(indices->shape[indices_index++]); + } + } + + return compute(oshape, [&](const Array& iter_vars) { + Array indices_indices; + for (size_t i = 0; i < iter_vars.size(); i++) { + if (i == axis) { + continue; + } + + indices_indices.push_back(iter_vars[i]); } - auto idx = iter_vars[iter_vars.size() - 1]; - auto ret = ir::Select::make(indices(outer_indices) == idx, on_value, off_value); + auto idx = iter_vars[axis]; + auto ret = ir::Select::make(indices(indices_indices) == idx, on_value, off_value); return tvm::cast(dtype, ret); }, name, tag); } From 1823bf1c0e9260c31559b3c14e2acd3cb2017646 Mon Sep 17 00:00:00 2001 From: Jon Date: Mon, 19 Aug 2019 16:11:39 -0700 Subject: [PATCH 12/19] Fix compute and update test_forward --- .../frontend/tensorflow/test_forward.py | 220 +++++++++--------- topi/include/topi/transform.h | 4 +- 2 files changed, 112 insertions(+), 112 deletions(-) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 37088684d606..9ac2bdb8b181 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2161,20 +2161,20 @@ def test_placeholder(): ####################################################################### # OneHot # ---------------------- -def _test_forward_one_hot(indices_shape, depth, out_dtype): +def _test_forward_one_hot(indices_shape, depth, on_value, off_value, axis, out_dtype): inp_array1 = np.random.randint(0, 5, size=indices_shape) with tf.Graph().as_default(): in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype) - out = tf.one_hot(in1, depth, dtype=out_dtype) + out = tf.one_hot(in1, depth, on_value, off_value, axis, dtype=out_dtype) compare_tf_with_tvm(inp_array1, in1.name, out.name) def test_forward_one_hot(): - _test_forward_one_hot((3,), 3, "int32") - _test_forward_one_hot((3,), 3, "float32") - _test_forward_one_hot((2, 2), 5, "int32") - _test_forward_one_hot((2, 2), 5, "float32") - _test_forward_one_hot((3, 2, 4, 5), 6, "int32") - _test_forward_one_hot((3, 2, 4, 5), 6, "float32") + _test_forward_one_hot((3,), 3, 1, 0, -1, "int32") + _test_forward_one_hot((3,), 3, 1.0, 0.0, -1, "float32") + _test_forward_one_hot((2, 2), 5, 2, -2, 0, "int32") + _test_forward_one_hot((2, 2), 5, 0.5, -0.5, 1, "float32") + _test_forward_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32") + _test_forward_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") ####################################################################### @@ -2183,110 +2183,110 @@ def test_forward_one_hot(): if __name__ == '__main__': # Transforms - test_forward_transpose() - test_forward_reshape() - test_forward_depthtospace() - test_forward_spacetodepth() - test_forward_squeeze() - test_forward_pack() - test_forward_size() - test_forward_broadcast_to() - test_forward_fill() - test_forward_crop() - test_forward_resize() - test_forward_crop_and_resize() - test_forward_pad() - test_forward_unpack() - test_forward_gather() - test_forward_gather_nd() - test_forward_stridedslice() - test_forward_split() - test_forward_unstack() - test_forward_tile() - test_forward_top_k_v2() - test_forward_clip_by_value() - test_forward_maximum() - test_forward_minimum() - test_forward_range() - test_forward_right_shift() - test_forward_left_shift() - test_forward_truncatemod() + # test_forward_transpose() + # test_forward_reshape() + # test_forward_depthtospace() + # test_forward_spacetodepth() + # test_forward_squeeze() + # test_forward_pack() + # test_forward_size() + # test_forward_broadcast_to() + # test_forward_fill() + # test_forward_crop() + # test_forward_resize() + # test_forward_crop_and_resize() + # test_forward_pad() + # test_forward_unpack() + # test_forward_gather() + # test_forward_gather_nd() + # test_forward_stridedslice() + # test_forward_split() + # test_forward_unstack() + # test_forward_tile() + # test_forward_top_k_v2() + # test_forward_clip_by_value() + # test_forward_maximum() + # test_forward_minimum() + # test_forward_range() + # test_forward_right_shift() + # test_forward_left_shift() + # test_forward_truncatemod() test_forward_one_hot() # Activations - test_forward_sigmoid() - test_forward_relu() - test_forward_leaky_relu() - test_forward_elu() - test_forward_selu() - test_forward_tanh() - - # Tensor - test_forward_round() - test_forward_reverse_v2() - test_forward_pow_exp() - test_forward_sign() - test_forward_log() - test_forward_log1p() - test_forward_cos() - test_forward_sin() - test_forward_negative() - test_forward_divide() - test_forward_abs() - test_forward_softplus() - test_forward_sqrt() - test_forward_rsqrt() - test_forward_expand_dims() - test_forward_square() - test_forward_softmax() - test_forward_log_softmax() - test_forward_bias_add() - test_forward_zeros_like() - - # Reductions - test_forward_argminmax() - test_forward_reduce() - test_forward_mean() - test_forward_reduce_prod() - test_forward_reduce_all() - test_forward_reduce_max() - test_forward_reduce_min() - - # General - test_forward_multi_input() - test_forward_multi_output() - test_forward_variable() - test_placeholder() - - # NN - test_forward_convolution() - test_forward_pooling() - test_forward_concat_v2() - test_forward_lrn() - test_forward_l2_normalize() - test_forward_space_to_batch_nd() - test_forward_batch_to_space_nd() - - # End to End - test_forward_inception_v3() - test_forward_inception_v1() - test_forward_mobilenet() - test_forward_resnetv2() - test_forward_placeholder() - test_forward_ptb() - - # RNN - test_forward_lstm() - - # Elementwise - test_forward_ceil() - test_forward_floor() - - # Relational ops - test_forward_rel_ops() - test_forward_logical() - test_forward_where() - test_forward_matmul() - test_forward_batch_matmul() + # test_forward_sigmoid() + # test_forward_relu() + # test_forward_leaky_relu() + # test_forward_elu() + # test_forward_selu() + # test_forward_tanh() + + # # Tensor + # test_forward_round() + # test_forward_reverse_v2() + # test_forward_pow_exp() + # test_forward_sign() + # test_forward_log() + # test_forward_log1p() + # test_forward_cos() + # test_forward_sin() + # test_forward_negative() + # test_forward_divide() + # test_forward_abs() + # test_forward_softplus() + # test_forward_sqrt() + # test_forward_rsqrt() + # test_forward_expand_dims() + # test_forward_square() + # test_forward_softmax() + # test_forward_log_softmax() + # test_forward_bias_add() + # test_forward_zeros_like() + + # # Reductions + # test_forward_argminmax() + # test_forward_reduce() + # test_forward_mean() + # test_forward_reduce_prod() + # test_forward_reduce_all() + # test_forward_reduce_max() + # test_forward_reduce_min() + + # # General + # test_forward_multi_input() + # test_forward_multi_output() + # test_forward_variable() + # test_placeholder() + + # # NN + # test_forward_convolution() + # test_forward_pooling() + # test_forward_concat_v2() + # test_forward_lrn() + # test_forward_l2_normalize() + # test_forward_space_to_batch_nd() + # test_forward_batch_to_space_nd() + + # # End to End + # test_forward_inception_v3() + # test_forward_inception_v1() + # test_forward_mobilenet() + # test_forward_resnetv2() + # test_forward_placeholder() + # test_forward_ptb() + + # # RNN + # test_forward_lstm() + + # # Elementwise + # test_forward_ceil() + # test_forward_floor() + + # # Relational ops + # test_forward_rel_ops() + # test_forward_logical() + # test_forward_where() + # test_forward_matmul() + # test_forward_batch_matmul() # TODO missing tests: rank diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index ead230cffa1a..9eeb33864f5c 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1282,14 +1282,14 @@ inline Tensor one_hot(const Tensor& indices, return compute(oshape, [&](const Array& iter_vars) { Array indices_indices; for (size_t i = 0; i < iter_vars.size(); i++) { - if (i == axis) { + if (i == true_axis) { continue; } indices_indices.push_back(iter_vars[i]); } - auto idx = iter_vars[axis]; + auto idx = iter_vars[true_axis]; auto ret = ir::Select::make(indices(indices_indices) == idx, on_value, off_value); return tvm::cast(dtype, ret); }, name, tag); From 67363e54ac42fdb862c435594d4c9f0ddd472e1c Mon Sep 17 00:00:00 2001 From: Jon Date: Mon, 19 Aug 2019 16:59:26 -0700 Subject: [PATCH 13/19] Move on_value and off_value to inputs --- include/tvm/relay/attrs/transform.h | 6 - python/tvm/relay/frontend/tensorflow.py | 12 +- python/tvm/relay/op/transform.py | 10 +- src/relay/op/tensor/transform.cc | 30 ++- .../frontend/tensorflow/test_forward.py | 204 +++++++++--------- topi/include/topi/transform.h | 13 +- 6 files changed, 140 insertions(+), 135 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 32b1cc7de9f6..52656872ad10 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -301,18 +301,12 @@ struct NdarraySizeAttrs : public tvm::AttrsNode { /*! \brief Attributes used in one-hot operator */ struct OneHotAttrs : public tvm::AttrsNode { int depth; - double on_value; - double off_value; int axis; DataType dtype; TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") { TVM_ATTR_FIELD(depth).set_default(1) .describe("Depth of the one hot dimension."); - TVM_ATTR_FIELD(on_value).set_default(1) - .describe("On value."); - TVM_ATTR_FIELD(off_value).set_default(0) - .describe("Off value."); TVM_ATTR_FIELD(axis).set_default(-1) .describe("Axis to fill."); TVM_ATTR_FIELD(dtype).set_default(NullValue()) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index cfdd8fcb050b..7ef3b96ed4f1 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1214,13 +1214,15 @@ def _impl(inputs, attr, params): def _one_hot(): def _impl(inputs, attr, params): - depth = int(_infer_value(inputs[1], params).asnumpy()[0]) - on_value = float(_infer_value(inputs[2], params).asnumpy()[0]) - off_value = float(_infer_value(inputs[3], params).asnumpy()[0]) + depth = int(_get_num_param(params, inputs[1])) + dtype = attr['T'].name + + on_value = _get_num_param(params, inputs[2]) + off_value = _get_num_param(params, inputs[3]) + new_inputs = [inputs[0], tvm.relay.const(on_value, dtype), tvm.relay.const(off_value, dtype)] return AttrCvt('one_hot', ignores=['TI'], - extras={'depth' : depth, 'on_value' : on_value, 'off_value' : off_value,\ - 'dtype' : attr['T'].name})([inputs[0]], attr) + extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr) return _impl # compatible operators that do NOT require any conversion. diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 506b2f487b45..6fc0aa5a18ed 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -749,7 +749,7 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0): """ return _make.sequence_mask(data, valid_length, mask_value, axis) -def one_hot(indices, depth, on_value, off_value, axis, dtype): +def one_hot(indices, on_value, off_value, depth, axis, dtype): """ Returns a one-hot tensor where the locations repsented by indices take value on_value, other locations take value off_value. @@ -760,15 +760,15 @@ def one_hot(indices, depth, on_value, off_value, axis, dtype): indices : relay.Expr Locations to set to on_value. - depth : int - Depth of the one-hot dimension. - on_value : float Value to fill at indices. off_value : float Value to fill at all other positions besides indices. + depth : int + Depth of the one-hot dimension. + Returns ------- ret : relay.Expr @@ -785,4 +785,4 @@ def one_hot(indices, depth, on_value, off_value, axis, dtype): [0, 1, 0], [0, 0, 1]] """ - return _make.one_hot(indices, depth, on_value, off_value, axis, dtype) + return _make.one_hot(indices, on_value, off_value, depth, axis, dtype) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 3d8a06e5629a..cf81ab10e858 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2489,8 +2489,8 @@ bool OneHotRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - // `types` contains: [indices, result] - CHECK_EQ(types.size(), 2); + // `types` contains: [indices, on_value, off_value, result] + CHECK_EQ(types.size(), 4); const auto* indices = types[0].as(); CHECK(indices); @@ -2509,7 +2509,7 @@ bool OneHotRel(const Array& types, } } - reporter->Assign(types[1], TensorTypeNode::make(oshape, param->dtype)); + reporter->Assign(types[3], TensorTypeNode::make(oshape, param->dtype)); return true; } @@ -2519,23 +2519,21 @@ Array OneHotCompute(const Attrs& attrs, const Target& target) { const auto* param = attrs.as(); CHECK(param != nullptr); - return Array{ topi::one_hot(inputs[0], param->depth, (float)param->on_value, (float)param->off_value, param->axis, param->dtype) }; + return Array{ topi::one_hot(inputs[0], inputs[1](), inputs[2](), param->depth, param->axis, param->dtype) }; } Expr MakeOneHot(Expr indices, + Expr on_value, + Expr off_value, int depth, - double on_value, - double off_value, int axis, DataType dtype) { auto attrs = make_node(); attrs->depth = std::move(depth); - attrs->on_value = on_value; - attrs->off_value = off_value; attrs->axis = axis; attrs->dtype = dtype; static const Op& op = Op::Get("one_hot"); - return CallNode::make(op, {indices}, Attrs(attrs), {}); + return CallNode::make(op, {indices, on_value, off_value}, Attrs(attrs), {}); } TVM_REGISTER_API("relay.op._make.one_hot") @@ -2547,10 +2545,20 @@ RELAY_REGISTER_OP("one_hot") **indices** Locations to set to 1. - **depth** Depth of the one-hot dimension.)code" TVM_ADD_FILELINE) + **on_value** Value to fill at indices. + + **off_value** Value to fill at all other positions besides indices. + + **depth** Depth of the one-hot dimension. + + **axis** Axis of one-hot dimension. + + **dtype**)code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.OneHotAttrs") -.set_num_inputs(1) +.set_num_inputs(3) .add_argument("indices", "Tensor", "Locations to set to on_value.") +.add_argument("on_value", "Expr", "Value to fill at indices.") +.add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.") .set_support_level(10) .add_type_rel("OneHot", OneHotRel) .set_attr("FTVMCompute", OneHotCompute) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 9ac2bdb8b181..cfa7e3cfd29d 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2183,110 +2183,110 @@ def test_forward_one_hot(): if __name__ == '__main__': # Transforms - # test_forward_transpose() - # test_forward_reshape() - # test_forward_depthtospace() - # test_forward_spacetodepth() - # test_forward_squeeze() - # test_forward_pack() - # test_forward_size() - # test_forward_broadcast_to() - # test_forward_fill() - # test_forward_crop() - # test_forward_resize() - # test_forward_crop_and_resize() - # test_forward_pad() - # test_forward_unpack() - # test_forward_gather() - # test_forward_gather_nd() - # test_forward_stridedslice() - # test_forward_split() - # test_forward_unstack() - # test_forward_tile() - # test_forward_top_k_v2() - # test_forward_clip_by_value() - # test_forward_maximum() - # test_forward_minimum() - # test_forward_range() - # test_forward_right_shift() - # test_forward_left_shift() - # test_forward_truncatemod() + test_forward_transpose() + test_forward_reshape() + test_forward_depthtospace() + test_forward_spacetodepth() + test_forward_squeeze() + test_forward_pack() + test_forward_size() + test_forward_broadcast_to() + test_forward_fill() + test_forward_crop() + test_forward_resize() + test_forward_crop_and_resize() + test_forward_pad() + test_forward_unpack() + test_forward_gather() + test_forward_gather_nd() + test_forward_stridedslice() + test_forward_split() + test_forward_unstack() + test_forward_tile() + test_forward_top_k_v2() + test_forward_clip_by_value() + test_forward_maximum() + test_forward_minimum() + test_forward_range() + test_forward_right_shift() + test_forward_left_shift() + test_forward_truncatemod() test_forward_one_hot() # Activations - # test_forward_sigmoid() - # test_forward_relu() - # test_forward_leaky_relu() - # test_forward_elu() - # test_forward_selu() - # test_forward_tanh() - - # # Tensor - # test_forward_round() - # test_forward_reverse_v2() - # test_forward_pow_exp() - # test_forward_sign() - # test_forward_log() - # test_forward_log1p() - # test_forward_cos() - # test_forward_sin() - # test_forward_negative() - # test_forward_divide() - # test_forward_abs() - # test_forward_softplus() - # test_forward_sqrt() - # test_forward_rsqrt() - # test_forward_expand_dims() - # test_forward_square() - # test_forward_softmax() - # test_forward_log_softmax() - # test_forward_bias_add() - # test_forward_zeros_like() - - # # Reductions - # test_forward_argminmax() - # test_forward_reduce() - # test_forward_mean() - # test_forward_reduce_prod() - # test_forward_reduce_all() - # test_forward_reduce_max() - # test_forward_reduce_min() - - # # General - # test_forward_multi_input() - # test_forward_multi_output() - # test_forward_variable() - # test_placeholder() - - # # NN - # test_forward_convolution() - # test_forward_pooling() - # test_forward_concat_v2() - # test_forward_lrn() - # test_forward_l2_normalize() - # test_forward_space_to_batch_nd() - # test_forward_batch_to_space_nd() - - # # End to End - # test_forward_inception_v3() - # test_forward_inception_v1() - # test_forward_mobilenet() - # test_forward_resnetv2() - # test_forward_placeholder() - # test_forward_ptb() - - # # RNN - # test_forward_lstm() - - # # Elementwise - # test_forward_ceil() - # test_forward_floor() - - # # Relational ops - # test_forward_rel_ops() - # test_forward_logical() - # test_forward_where() - # test_forward_matmul() - # test_forward_batch_matmul() + test_forward_sigmoid() + test_forward_relu() + test_forward_leaky_relu() + test_forward_elu() + test_forward_selu() + test_forward_tanh() + + # Tensor + test_forward_round() + test_forward_reverse_v2() + test_forward_pow_exp() + test_forward_sign() + test_forward_log() + test_forward_log1p() + test_forward_cos() + test_forward_sin() + test_forward_negative() + test_forward_divide() + test_forward_abs() + test_forward_softplus() + test_forward_sqrt() + test_forward_rsqrt() + test_forward_expand_dims() + test_forward_square() + test_forward_softmax() + test_forward_log_softmax() + test_forward_bias_add() + test_forward_zeros_like() + + # Reductions + test_forward_argminmax() + test_forward_reduce() + test_forward_mean() + test_forward_reduce_prod() + test_forward_reduce_all() + test_forward_reduce_max() + test_forward_reduce_min() + + # General + test_forward_multi_input() + test_forward_multi_output() + test_forward_variable() + test_placeholder() + + # NN + test_forward_convolution() + test_forward_pooling() + test_forward_concat_v2() + test_forward_lrn() + test_forward_l2_normalize() + test_forward_space_to_batch_nd() + test_forward_batch_to_space_nd() + + # End to End + test_forward_inception_v3() + test_forward_inception_v1() + test_forward_mobilenet() + test_forward_resnetv2() + test_forward_placeholder() + test_forward_ptb() + + # RNN + test_forward_lstm() + + # Elementwise + test_forward_ceil() + test_forward_floor() + + # Relational ops + test_forward_rel_ops() + test_forward_logical() + test_forward_where() + test_forward_matmul() + test_forward_batch_matmul() # TODO missing tests: rank diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 9eeb33864f5c..873fe0ad0e39 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1251,18 +1251,18 @@ inline Tensor ndarray_size(const Tensor& src, * \brief Returns a one-hot tensor where the locations repsented by indices take value on_value, other locations take value off_value. * \param indices locations to set to on_value. - * \param depth depth of the one-hot dimension. * \param on_value value that locations represented by indices take on. * \param off_value value that other locations take on. + * \param depth depth of the one-hot dimension. * \param axis axis of one-hot dimension. * \param name output tensor name. * \param tag output tensor tag. * \return one-hot tensor. */ inline Tensor one_hot(const Tensor& indices, + const Expr on_value, + const Expr off_value, int depth, - float on_value, - float off_value, int axis, const Type& dtype, const std::string name = "T_one_hot", @@ -1279,10 +1279,12 @@ inline Tensor one_hot(const Tensor& indices, } } + Expr on_value_cast = cast(dtype, on_value); + Expr off_value_cast = cast(dtype, off_value); return compute(oshape, [&](const Array& iter_vars) { Array indices_indices; for (size_t i = 0; i < iter_vars.size(); i++) { - if (i == true_axis) { + if ((int)i == true_axis) { continue; } @@ -1290,8 +1292,7 @@ inline Tensor one_hot(const Tensor& indices, } auto idx = iter_vars[true_axis]; - auto ret = ir::Select::make(indices(indices_indices) == idx, on_value, off_value); - return tvm::cast(dtype, ret); + return ir::Select::make(indices(indices_indices) == idx, on_value_cast, off_value_cast); }, name, tag); } From ffd2fb8ded1edfa9c1a97880fd2317cd1e61c330 Mon Sep 17 00:00:00 2001 From: Jon Date: Tue, 20 Aug 2019 10:49:44 -0700 Subject: [PATCH 14/19] Add topi test --- python/tvm/relay/op/transform.py | 10 ++- src/relay/op/tensor/transform.cc | 2 +- topi/include/topi/transform.h | 3 +- topi/python/topi/testing/__init__.py | 1 + topi/python/topi/testing/one_hot.py | 79 ++++++++++++++++++++++++ topi/python/topi/transform.py | 36 +++++++++-- topi/src/topi.cc | 5 +- topi/tests/python/test_topi_transform.py | 18 +++--- 8 files changed, 134 insertions(+), 20 deletions(-) create mode 100644 topi/python/topi/testing/one_hot.py diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 6fc0aa5a18ed..b43608ad5cf2 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -760,15 +760,21 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype): indices : relay.Expr Locations to set to on_value. - on_value : float + on_value : relay.Expr Value to fill at indices. - off_value : float + off_value : relay.Expr Value to fill at all other positions besides indices. depth : int Depth of the one-hot dimension. + axis : int + Axis to fill. + + dtype : str + Data type of the output tensor. + Returns ------- ret : relay.Expr diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index cf81ab10e858..721be9f8149b 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2551,7 +2551,7 @@ RELAY_REGISTER_OP("one_hot") **depth** Depth of the one-hot dimension. - **axis** Axis of one-hot dimension. + **axis** Axis to fill. **dtype**)code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.OneHotAttrs") diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 873fe0ad0e39..4994f39bd527 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1254,7 +1254,8 @@ inline Tensor ndarray_size(const Tensor& src, * \param on_value value that locations represented by indices take on. * \param off_value value that other locations take on. * \param depth depth of the one-hot dimension. - * \param axis axis of one-hot dimension. + * \param axis axis to fill. + * \param dtype data type of the output tensor. * \param name output tensor name. * \param tag output tensor tag. * \return one-hot tensor. diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 57a9c264edc4..d607c28dccdb 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -25,3 +25,4 @@ from .slice_axis_python import slice_axis_python from .sequence_mask_python import sequence_mask from .pool_grad_python import pool_grad_nchw +from .one_hot import one_hot diff --git a/topi/python/topi/testing/one_hot.py b/topi/python/topi/testing/one_hot.py new file mode 100644 index 000000000000..a366e78eab89 --- /dev/null +++ b/topi/python/topi/testing/one_hot.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""OneHot in python""" +import numpy as np + +def one_hot(indices, on_value, off_value, depth, axis, dtype): + """one_hot operator implemented in numpy. + + Returns a one-hot tensor where the locations repsented by indices take value on_value, + other locations take value off_value. + Final dimension is x depth x . + + Parameters + ---------- + indices : numpy.ndarray + Locations to set to on_value. + + on_value : int/float + Value to fill at indices. + + off_value : int/float + Value to fill at all other positions besides indices. + + depth : int + Depth of the one-hot dimension. + + axis : int + Axis to fill. + + dtype : str + Data type of the output tensor. + + Returns + ------- + ret : relay.Expr + The one-hot tensor. + """ + oshape = [] + true_axis = len(indices.shape) if axis == -1 else axis + ndim = len(indices.shape) + 1 + indices_index = 0 + for i in range(0, ndim): + if i == true_axis: + oshape.append(depth) + else: + oshape.append(indices.shape[indices_index]) + indices_index += 1 + + out = np.empty(oshape) + output_indices = [index for index in np.ndindex(out.shape)] + for output_index in output_indices: + indices_indices = [] + for i in range(0, len(output_index)): + if i == true_axis: + continue + indices_indices.append(output_index[i]) + + index = output_index[true_axis] + if indices[tuple(indices_indices)] == index: + out[output_index] = on_value + else: + out[output_index] = off_value + + return out.astype(dtype) \ No newline at end of file diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index e3146717a7d2..ee89091d5070 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -519,22 +519,46 @@ def where(condition, x, y): """ return cpp.where(condition, x, y) -def one_hot(indices, depth): +def one_hot(indices, on_value, off_value, depth, axis, dtype): """ - Returns a one-hot tensor where the locations repsented by indices take value 1, - other locations take value 0. + Returns a one-hot tensor where the locations repsented by indices take value on_value, + other locations take value off_value. + Final dimension is x depth x . Parameters ---------- - indices : relay.Expr - Locations to set to 1. + indices : tvm.Tensor + Locations to set to on_value. + + on_value : tvm.Tensor + Value to fill at indices. + + off_value : tvm.Tensor + Value to fill at all other positions besides indices. depth : int Depth of the one-hot dimension. + axis : int + Axis to fill. + + dtype : relay.DataType + Data type of the output tensor. + Returns ------- ret : relay.Expr The one-hot tensor. + + Examples + -------- + .. code-block:: python + + indices = [0, 1, 2] + + relay.one_hot(indices, 3) = + [[1, 0, 0], + [0, 1, 0], + [0, 0, 1]] """ - return cpp.one_hot(indices, depth) + return cpp.one_hot(indices, on_value, off_value, depth, axis, dtype) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index f3663bbeb1b0..7e47b62a8af4 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -419,11 +419,10 @@ TVM_REGISTER_GLOBAL("topi.strided_slice") TVM_REGISTER_GLOBAL("topi.one_hot") .set_body([](TVMArgs args, TVMRetValue *rv) { - double on_value = args[2]; - double off_value = args[3]; + int depth = args[3]; int axis = args[4]; DataType dtype = args[5]; - *rv = one_hot(args[0], args[1], (float)on_value, (float)off_value, axis, dtype); + *rv = one_hot(args[0], args[1], args[2], depth, axis, dtype); }); /* Ops from nn/upsampling.h */ diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 35e436bba73f..5c0ae133e7b7 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -473,9 +473,11 @@ def check_device(device): for device in get_all_backend(): check_device(device) -def verify_one_hot(indices_shape, depth): +def verify_one_hot(indices_shape, depth, on_value, off_value, axis, out_dtype): indices = tvm.placeholder(shape=indices_shape, name="indices", dtype="int32") - one_hot_result = topi.transform.one_hot(indices, depth) + on_value_const = tvm.const(on_value, out_dtype) + off_value_const = tvm.const(off_value, out_dtype) + one_hot_result = topi.transform.one_hot(indices, on_value_const, off_value_const, depth, axis, out_dtype) def check_device(device): ctx = tvm.context(device, 0) if not ctx.exist: @@ -486,7 +488,7 @@ def check_device(device): s = topi.generic.schedule_injective(one_hot_result) fn = tvm.build(s, [indices, one_hot_result], device, name="one_hot") indices_npy = np.random.randint(0, depth, size=indices_shape).astype(indices.dtype) - out_npy = np.eye(depth)[indices_npy] + out_npy = topi.testing.one_hot(indices_npy, on_value, off_value, depth, axis, out_dtype) indices_nd = tvm.nd.array(indices_npy, ctx) out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(one_hot_result.dtype), ctx) fn(indices_nd, out_nd) @@ -794,10 +796,12 @@ def check_device(device): check_device(backend) def test_one_hot(): - verify_one_hot((3,), 3) - verify_one_hot((4,), 3) - verify_one_hot((2, 2), 5) - verify_one_hot((3, 2, 4, 5), 6) + verify_one_hot((3,), 3, 1, 0, -1, "int32") + verify_one_hot((3,), 3, 1.0, 0.0, -1, "float32") + verify_one_hot((2, 2), 5, 2, -2, 0, "int32") + verify_one_hot((2, 2), 5, 0.5, -0.5, 1, "float32") + verify_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32") + verify_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") if __name__ == "__main__": test_strided_slice() From 66366ecd3bcd7446678cb61e601046fbf2bc2483 Mon Sep 17 00:00:00 2001 From: Jon Date: Tue, 20 Aug 2019 11:07:01 -0700 Subject: [PATCH 15/19] Update tests --- tests/python/relay/test_op_level10.py | 58 ++++++++++++++++++++---- topi/tests/python/test_topi_transform.py | 10 ++-- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index f3520f3650a3..46dc25a011be 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -296,14 +296,54 @@ def _verify(data_shape, mask_value, axis, dtype, itype): _verify((2, 3, 5, 3), 0.0, 0, 'float32', 'int64') _verify((5, 8, 3), 0.1, 1, 'float64', 'float32') +def test_one_hot(): + def _get_oshape(indices_shape, depth, axis): + oshape = [] + true_axis = len(indices_shape) if axis == -1 else axis + ndim = len(indices_shape) + 1 + indices_index = 0 + for i in range(0, ndim): + if i == true_axis: + oshape.append(depth) + else: + oshape.append(indices_shape[indices_index]) + indices_index += 1 + + return oshape + + def _verify(indices_shape, depth, on_value, off_value, axis, dtype): + indices = relay.var("indices", relay.TensorType(indices_shape, "int32")) + on_value_const = relay.const(on_value) + off_value_const = relay.const(off_value) + out = relay.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype) + checked = run_infer_type(out) + assert checked.checked_type == relay.ty.TensorType(_get_oshape(indices_shape, depth, axis), dtype) + func = relay.Function([indices], out) + indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32") + out_np = topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + out_relay = intrp.evaluate(func)(indices_np) + tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) + + _verify((3,), 3, 1, 0, -1, "int32") + _verify((3,), 3, 1.0, 0.0, -1, "float32") + _verify((2, 2), 5, 2, -2, 0, "int32") + _verify((2, 2), 5, 0.5, -0.5, 1, "float32") + _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32") + _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") + if __name__ == "__main__": - test_adaptive_pool2d() - test_collapse_sum_like() - test_broadcast_to_like() - test_slice_like() - test_reverse_reshape() - test_batch_matmul() - test_shape_of() - test_sequence_mask() - test_ndarray_size() + # test_adaptive_pool2d() + # test_collapse_sum_like() + # test_broadcast_to_like() + # test_slice_like() + # test_reverse_reshape() + # test_batch_matmul() + # test_shape_of() + # test_sequence_mask() + # test_ndarray_size() + test_one_hot() diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 5c0ae133e7b7..b1aa20ea07df 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -473,11 +473,11 @@ def check_device(device): for device in get_all_backend(): check_device(device) -def verify_one_hot(indices_shape, depth, on_value, off_value, axis, out_dtype): +def verify_one_hot(indices_shape, depth, on_value, off_value, axis, dtype): indices = tvm.placeholder(shape=indices_shape, name="indices", dtype="int32") - on_value_const = tvm.const(on_value, out_dtype) - off_value_const = tvm.const(off_value, out_dtype) - one_hot_result = topi.transform.one_hot(indices, on_value_const, off_value_const, depth, axis, out_dtype) + on_value_const = tvm.const(on_value, dtype) + off_value_const = tvm.const(off_value, dtype) + one_hot_result = topi.transform.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype) def check_device(device): ctx = tvm.context(device, 0) if not ctx.exist: @@ -488,7 +488,7 @@ def check_device(device): s = topi.generic.schedule_injective(one_hot_result) fn = tvm.build(s, [indices, one_hot_result], device, name="one_hot") indices_npy = np.random.randint(0, depth, size=indices_shape).astype(indices.dtype) - out_npy = topi.testing.one_hot(indices_npy, on_value, off_value, depth, axis, out_dtype) + out_npy = topi.testing.one_hot(indices_npy, on_value, off_value, depth, axis, dtype) indices_nd = tvm.nd.array(indices_npy, ctx) out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(one_hot_result.dtype), ctx) fn(indices_nd, out_nd) From 9fc1bb7f850793c2834d9c14c024e67cce0d0e5e Mon Sep 17 00:00:00 2001 From: Jon Date: Tue, 20 Aug 2019 11:08:59 -0700 Subject: [PATCH 16/19] Update docs --- docs/api/python/topi.rst | 2 ++ docs/langref/relay_op.rst | 2 ++ 2 files changed, 4 insertions(+) diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 8f59e08c0797..123c1d0c08bb 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -104,6 +104,7 @@ List of operators topi.argsort topi.topk topi.sequence_mask + topi.one_hot List of schedules @@ -173,6 +174,7 @@ topi .. autofunction:: topi.argsort .. autofunction:: topi.topk .. autofunction:: topi.sequence_mask +.. autofunction:: topi.one_hot topi.nn ~~~~~~~ diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 6950ecceee05..4fad352ecd6f 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -200,6 +200,7 @@ This level support backpropagation of broadcast operators. It is temporary. tvm.relay.nn.batch_matmul tvm.relay.contrib.adaptive_max_pool2d tvm.relay.contrib.adaptive_avg_pool2d + tvm.relay.one_hot **Level 11: Dialect Operators** @@ -350,6 +351,7 @@ Level 10 Definitions .. autofunction:: tvm.relay.nn.batch_matmul .. autofunction:: tvm.relay.contrib.adaptive_max_pool2d .. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d +.. autofunction:: tvm.relay.one_hot Level 11 Definitions From 93d59661c66ef637e9f1b34e951fce934fc6e2c5 Mon Sep 17 00:00:00 2001 From: Jon Date: Tue, 20 Aug 2019 11:26:04 -0700 Subject: [PATCH 17/19] Fix style --- python/tvm/relay/frontend/tensorflow.py | 6 ++++-- python/tvm/relay/op/transform.py | 2 +- src/relay/op/tensor/transform.cc | 9 ++++++++- topi/include/topi/transform.h | 2 +- topi/python/topi/testing/one_hot.py | 8 ++++---- topi/python/topi/transform.py | 2 +- 6 files changed, 19 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 7ef3b96ed4f1..e5fea7faa7ca 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1216,10 +1216,12 @@ def _one_hot(): def _impl(inputs, attr, params): depth = int(_get_num_param(params, inputs[1])) dtype = attr['T'].name - + on_value = _get_num_param(params, inputs[2]) off_value = _get_num_param(params, inputs[3]) - new_inputs = [inputs[0], tvm.relay.const(on_value, dtype), tvm.relay.const(off_value, dtype)] + new_inputs = [inputs[0], \ + tvm.relay.const(on_value, dtype), \ + tvm.relay.const(off_value, dtype)] return AttrCvt('one_hot', ignores=['TI'], extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index b43608ad5cf2..38ce653c0716 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -752,7 +752,7 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0): def one_hot(indices, on_value, off_value, depth, axis, dtype): """ Returns a one-hot tensor where the locations repsented by indices take value on_value, - other locations take value off_value. + other locations take value off_value. Final dimension is x depth x . Parameters diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 721be9f8149b..b39c282b1d96 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2519,7 +2519,14 @@ Array OneHotCompute(const Attrs& attrs, const Target& target) { const auto* param = attrs.as(); CHECK(param != nullptr); - return Array{ topi::one_hot(inputs[0], inputs[1](), inputs[2](), param->depth, param->axis, param->dtype) }; + return Array { + topi::one_hot(inputs[0], + inputs[1](), + inputs[2](), + param->depth, + param->axis, + param->dtype) + }; } Expr MakeOneHot(Expr indices, diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 4994f39bd527..1622b208c53e 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1285,7 +1285,7 @@ inline Tensor one_hot(const Tensor& indices, return compute(oshape, [&](const Array& iter_vars) { Array indices_indices; for (size_t i = 0; i < iter_vars.size(); i++) { - if ((int)i == true_axis) { + if (static_cast(i) == true_axis) { continue; } diff --git a/topi/python/topi/testing/one_hot.py b/topi/python/topi/testing/one_hot.py index a366e78eab89..99c52be65c74 100644 --- a/topi/python/topi/testing/one_hot.py +++ b/topi/python/topi/testing/one_hot.py @@ -22,7 +22,7 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype): """one_hot operator implemented in numpy. Returns a one-hot tensor where the locations repsented by indices take value on_value, - other locations take value off_value. + other locations take value off_value. Final dimension is x depth x . Parameters @@ -65,10 +65,10 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype): output_indices = [index for index in np.ndindex(out.shape)] for output_index in output_indices: indices_indices = [] - for i in range(0, len(output_index)): + for i, out_idx in enumerate(output_index): if i == true_axis: continue - indices_indices.append(output_index[i]) + indices_indices.append(out_idx) index = output_index[true_axis] if indices[tuple(indices_indices)] == index: @@ -76,4 +76,4 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype): else: out[output_index] = off_value - return out.astype(dtype) \ No newline at end of file + return out.astype(dtype) diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index ee89091d5070..3c7fc9c0dffb 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -522,7 +522,7 @@ def where(condition, x, y): def one_hot(indices, on_value, off_value, depth, axis, dtype): """ Returns a one-hot tensor where the locations repsented by indices take value on_value, - other locations take value off_value. + other locations take value off_value. Final dimension is x depth x . Parameters From 284e352ef25fb3221697a0e45b71c9800c555d99 Mon Sep 17 00:00:00 2001 From: Jon Date: Wed, 21 Aug 2019 10:04:08 -0700 Subject: [PATCH 18/19] re-enable tests --- tests/python/relay/test_op_level10.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 46dc25a011be..e828fa30de56 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -336,14 +336,14 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") if __name__ == "__main__": - # test_adaptive_pool2d() - # test_collapse_sum_like() - # test_broadcast_to_like() - # test_slice_like() - # test_reverse_reshape() - # test_batch_matmul() - # test_shape_of() - # test_sequence_mask() - # test_ndarray_size() + test_adaptive_pool2d() + test_collapse_sum_like() + test_broadcast_to_like() + test_slice_like() + test_reverse_reshape() + test_batch_matmul() + test_shape_of() + test_sequence_mask() + test_ndarray_size() test_one_hot() From a7fdb2acb922da3dc6a24e6a1ace59403c51a0bf Mon Sep 17 00:00:00 2001 From: Jon Date: Thu, 22 Aug 2019 11:18:02 -0700 Subject: [PATCH 19/19] Add one_hot to mxnet converter --- python/tvm/relay/frontend/mxnet.py | 9 +++++++++ tests/python/frontend/mxnet/test_forward.py | 20 ++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 9d82671e5534..36c4fb895874 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -896,6 +896,14 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): ret.append(_op.stack(inputs, axis=0)) return ret +def _mx_one_hot(inputs, attrs): + indices = inputs[0].astype('int32') + depth = attrs.get_int('depth', 0) + dtype = attrs.get_str('dtype', 'int32') + on_value = tvm.relay.const(attrs.get_float('on_value', 1.0), dtype) + off_value = tvm.relay.const(attrs.get_float('off_value', 0.0), dtype) + return _op.one_hot(indices, on_value, off_value, depth, -1, dtype) + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free @@ -1041,6 +1049,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): "LinearRegressionOutput" : _mx_linear_regression_output, "smooth_l1" : _mx_smooth_l1, "_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim, + "one_hot" : _mx_one_hot, # vision "_contrib_BilinearResize2D" : _mx_resize, "_contrib_MultiBoxPrior" : _mx_multibox_prior, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index a4a514ea7474..90b425ff22ed 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -778,6 +778,25 @@ def verify(shape, axis=-1): verify((2, 5), axis=0) verify((2, 5, 6)) +def test_forward_one_hot(): + def verify(indices_shape, depth, on_value, off_value, dtype): + x = np.random.randint(0, 5, size=indices_shape) + ref_res = mx.nd.one_hot(mx.nd.array(x), depth, on_value, off_value, dtype) + mx_sym = mx.sym.one_hot(mx.sym.var("x"), depth, on_value, off_value, dtype) + shape_dict = {"x": x.shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x.astype("float32")) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3) + verify((3,), 3, 1, 0, "int32") + verify((3,), 3, 1.0, 0.0, "float32") + verify((2, 2), 5, 2, -2, "int32") + verify((2, 2), 5, 0.5, -0.5, "float32") + verify((3, 2, 4, 5), 6, 1, 0, "int32") + verify((3, 2, 4, 5), 6, 1.0, 0.0, "float32") + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -825,3 +844,4 @@ def verify(shape, axis=-1): test_forward_contrib_div_sqrt_dim() test_forward_batch_norm() test_forward_layer_norm() + test_forward_one_hot()