diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 71808902aeca..c2c639e374f5 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -110,18 +110,28 @@ template class IterAdapter { public: + using difference_type = typename std::iterator_traits::difference_type; + using value_type = typename std::iterator_traits::value_type; + using pointer = typename std::iterator_traits::pointer; + using reference = typename std::iterator_traits::reference; + using iterator_category = typename std::iterator_traits::iterator_category; + explicit IterAdapter(TIter iter) : iter_(iter) {} - inline IterAdapter& operator++() { // NOLINT(*) - ++iter_; - return *this; - } - inline IterAdapter& operator++(int) { // NOLINT(*) + inline IterAdapter& operator++() { ++iter_; return *this; } - inline IterAdapter operator+(int offset) const { // NOLINT(*) + inline IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } + + template + typename std::enable_if::value, + typename T::difference_type>::type + inline operator-(const IterAdapter& rhs) const { + return iter_ - rhs.iter_; + } + inline bool operator==(IterAdapter other) const { return iter_ == other.iter_; } diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 4a6e845b2f02..051c61be86f7 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -35,16 +35,6 @@ namespace tvm { namespace relay { -template -inline std::vector AsVector(const Array &array) { - std::vector result; - result.reserve(array.size()); - for (const T& ele : array) { - result.push_back(ele); - } - return result; -} - /*! Quick helper macro * - Expose a positional make function to construct the node. * - Register op to the registry. diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 647e4d0f4f90..d655665f2083 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -229,7 +229,7 @@ bool ArgReduceRel(const Array& types, const auto* data = types[0].as(); if (data == nullptr) return false; CHECK(static_cast(data->shape.size()) != 0); - std::vector&& in_shape = AsVector(data->shape); + std::vector in_shape(data->shape.begin(), data->shape.end()); const ReduceAttrs* param = attrs.as(); CHECK(param != nullptr); @@ -254,7 +254,7 @@ bool ReduceRel(const Array& types, CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) return false; - std::vector&& in_shape = AsVector(data->shape); + std::vector in_shape(data->shape.begin(), data->shape.end()); const ReduceAttrs* param = attrs.as(); CHECK(param != nullptr); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 0b501e2ff119..a7aeb03d5bdd 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -265,7 +265,7 @@ bool ConcatenateRel(const Array& types, } axis = axis < 0 ? ndim + axis : axis; // Calculate shape - std::vector&& oshape = AsVector(first->shape); + std::vector oshape(first->shape.begin(), first->shape.end()); IndexExpr &concat_dim = oshape[axis]; bool has_any = false; if (concat_dim.as()) { @@ -834,7 +834,7 @@ bool TakeRel(const Array& types, CHECK(param != nullptr); if (!param->axis.defined()) { - std::vector&& oshape = AsVector(indices->shape); + std::vector oshape(indices->shape.begin(), indices->shape.end()); reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); return true; } @@ -1990,7 +1990,7 @@ bool SplitRel(const Array& types, << "indices_or_sections need to be able to divide input.shape[axis]"; std::vector fields; for (int i = 0; i < sections->value; ++i) { - std::vector&& oshape = AsVector(data->shape); + std::vector oshape(data->shape.begin(), data->shape.end()); oshape[axis] /= int32_t(sections->value); auto vec_type = TensorTypeNode::make(oshape, data->dtype); fields.push_back(vec_type); @@ -2003,7 +2003,7 @@ bool SplitRel(const Array& types, for (unsigned int i = 0; i < indices.size(); ++i) { CHECK(reporter->Assert(IndexExpr(indices[i]) > begin)) << "indices_or_sections need to be a sorted ascending list"; - std::vector&& oshape = AsVector(data->shape); + std::vector oshape(data->shape.begin(), data->shape.end()); oshape[axis] = IndexExpr(indices[i]) - begin; begin = IndexExpr(indices[i]); auto vec_type = TensorTypeNode::make(oshape, data->dtype); @@ -2011,7 +2011,7 @@ bool SplitRel(const Array& types, } CHECK(reporter->Assert(begin < data->shape[axis])) << "The sum of sections must match the input.shape[axis]"; - std::vector&& oshape = AsVector(data->shape); + std::vector oshape(data->shape.begin(), data->shape.end()); oshape[axis] = data->shape[axis] - begin; auto vec_type = TensorTypeNode::make(oshape, data->dtype); fields.push_back(vec_type); @@ -2105,9 +2105,9 @@ bool SliceLikeRel(const Array& types, const auto param = attrs.as(); CHECK(param != nullptr); - const Array dshape = data->shape; - const Array target_shape = target->shape; - std::vector&& oshape = AsVector(dshape); + const Array& dshape = data->shape; + const Array& target_shape = target->shape; + std::vector oshape(dshape.begin(), dshape.end()); if (!param->axes.defined()) { for (size_t i = 0; i < dshape.size(); ++i) { diff --git a/src/relay/op/vision/yolo.cc b/src/relay/op/vision/yolo.cc index 0a1d9614976e..d525b428a2ec 100644 --- a/src/relay/op/vision/yolo.cc +++ b/src/relay/op/vision/yolo.cc @@ -53,7 +53,7 @@ bool YoloReorgRel(const Array& types, CHECK(param != nullptr); CHECK(data->shape.size() == 4) << "Yolo reorg supports only 4 dimension."; - std::vector&& oshape = AsVector(data->shape); + std::vector oshape(data->shape.begin(), data->shape.end()); oshape[1] = oshape[1] * param->stride * param->stride; oshape[2] = oshape[2] / param->stride; oshape[3] = oshape[3] / param->stride; diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index c9afedb057d5..005e15969a88 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -17,6 +17,8 @@ * under the License. */ +#include +#include #include #include #include @@ -42,6 +44,13 @@ TEST(Array, Mutate) { CHECK(list2[1].same_as(z)); } +TEST(Array, Iterator) { + using namespace tvm; + Array array{1, 2, 3}; + std::vector vector(array.begin(), array.end()); + CHECK(vector[1].as()->value == 2); +} + TEST(Map, Expr) { using namespace tvm; Var x("x"); @@ -86,6 +95,14 @@ TEST(Map, Mutate) { LOG(INFO) << dict; } +TEST(Map, Iterator) { + using namespace tvm; + Expr a = 1, b = 2; + Map map1{{a, b}}; + std::unordered_map map2(map1.begin(), map1.end()); + CHECK(map2[a].as()->value == 2); +} + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe";