From 1d4850d40a38d63e8715d062a530e659be71e7f4 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Sun, 5 Jan 2020 19:50:09 -0800 Subject: [PATCH] [Topi]Allow empty tensor for reshape, tile and strided_slice (#4618) * Support empty tensor * Fix schedule * Refactor * Minor fix * Fix pylint * Merge cpp and python is_empty_shape --- src/relay/op/tensor/transform.cc | 4 ++ topi/include/topi/detail/tensor_utils.h | 55 ++++++++++++++++++++++ topi/include/topi/transform.h | 59 ++++++++++++++++-------- topi/python/topi/arm_cpu/injective.py | 5 +- topi/python/topi/cpp/__init__.py | 1 + topi/python/topi/cpp/util.py | 21 +++++++++ topi/python/topi/cuda/injective.py | 4 +- topi/python/topi/util.py | 18 +++++++- topi/python/topi/x86/injective.py | 5 +- topi/src/topi.cc | 8 ++++ topi/tests/python/test_topi_transform.py | 3 ++ 11 files changed, 159 insertions(+), 24 deletions(-) create mode 100644 topi/include/topi/detail/tensor_utils.h create mode 100644 topi/python/topi/cpp/util.py diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index ca40d2a71915..5885a00cf2f0 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -888,6 +888,7 @@ bool TakeRel(const Array& types, CHECK(data != nullptr); const auto* indices = types[1].as(); CHECK(indices != nullptr); + CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer"; const auto param = attrs.as(); CHECK(param != nullptr); @@ -1648,6 +1649,9 @@ bool SqueezeRel(const Array& types, // if axes is None, squeeze all axes of dimension 1 if (!param->axis.defined()) { for (const auto& e : data->shape) { + if (!e.as()) { + LOG(FATAL) << "axis needs to be defined for dynamic input."; + } const int64_t* axis_ptr = as_const_int(e); CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete"; if (*axis_ptr != 1) { diff --git a/topi/include/topi/detail/tensor_utils.h b/topi/include/topi/detail/tensor_utils.h new file mode 100644 index 000000000000..ede3f883aab5 --- /dev/null +++ b/topi/include/topi/detail/tensor_utils.h @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tensor_utils.h + * \brief Utility functions for handling tensor + */ +#ifndef TOPI_DETAIL_TENSOR_UTILS_H_ +#define TOPI_DETAIL_TENSOR_UTILS_H_ + + +namespace topi { +namespace detail { +using namespace tvm; + +/*! + * \brief Check whether input shape has dimension of size 0; + * + * \param x Input shape + * + * \return True if the input shape is empty. + */ +inline bool is_empty_shape(const Array& x) { + bool is_empty = false; + for (const auto& dim : x) { + if (auto int_dim = dim.as()) { + if (int_dim->value == 0) { + is_empty = true; + break; + } + } + } + return is_empty; +} + +} // namespace detail +} // namespace topi +#endif // TOPI_DETAIL_TENSOR_UTILS_H_ + diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index e68b7a5a0b3c..9a66280bd4c5 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -34,6 +34,7 @@ #include "topi/tags.h" #include "topi/detail/ravel_unravel.h" #include "topi/detail/constant_utils.h" +#include "topi/detail/tensor_utils.h" #include "tvm/operation.h" #include "tvm/expr_operator.h" #include "tvm/data_layout.h" @@ -207,16 +208,28 @@ inline Tensor reshape(const Tensor& x, std::string name = "T_reshape", std::string tag = kInjective) { auto x_shape = x->shape; - Array newshape_int32; + Array target_shape; for (const auto &ele : newshape) { - newshape_int32.push_back(cast(DataType::Int(32), ele)); + if (ele.as()) { + target_shape.push_back(cast(DataType::Int(32), ele)); + } else { + target_shape.push_back(ele); + } + } + + if (is_empty_shape(target_shape)) { + return compute(target_shape, + [&](const Array &indices) { return tvm::cast(x->dtype, 0); }, + name, tag); + } else { + return compute( + target_shape, [&](const Array& indices) { + return x(UnravelIndex( + RavelIndex(Array{indices.begin(), indices.end()}, target_shape), + x_shape)); + }, name, tag); } - return compute( - newshape_int32, [&](const Array& indices) { - return x(UnravelIndex(RavelIndex(Array{indices.begin(), indices.end()}, newshape_int32), - x_shape)); - }, name, tag); } /*! @@ -556,7 +569,7 @@ inline Tensor strided_slice(const Tensor& x, int interval = std::abs(end_i - begin_i); int slice_size = static_cast((interval + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i])); - CHECK(stride_vec[i] < 0 ? (end_i < begin_i) : (begin_i < end_i)) + CHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] << "] is invalid for axis=" << i; @@ -938,18 +951,24 @@ inline Tensor tile(const Tensor& x, for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]); - return compute( - new_shape, [&](const Array& indices) { - Array idx; - if (ndim >= rdim) { - for (size_t i = 0; i < ndim; ++i) - idx.push_back(indexmod(indices[i], x->shape[i])); - } else { - for (size_t i = 0; i < ndim; ++i) - idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i])); - } - return x(idx); - }, name, tag); + if (is_empty_shape(new_shape)) { + return compute(new_shape, + [&](const Array& indices) { return tvm::cast(x->dtype, 0);}, + name, tag); + } else { + return compute( + new_shape, [&](const Array& indices) { + Array idx; + if (ndim >= rdim) { + for (size_t i = 0; i < ndim; ++i) + idx.push_back(indexmod(indices[i], x->shape[i])); + } else { + for (size_t i = 0; i < ndim; ++i) + idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i])); + } + return x(idx); + }, name, tag); + } } /*! diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py index 3727754988f3..0b6a16d37d1a 100644 --- a/topi/python/topi/arm_cpu/injective.py +++ b/topi/python/topi/arm_cpu/injective.py @@ -18,6 +18,7 @@ """Schedule for pooling operators""" import tvm from .. import generic +from ..util import is_empty_shape @generic.schedule_injective_from_existing.register(["arm_cpu"]) def schedule_injective_from_existing(sch, out): @@ -68,7 +69,9 @@ def schedule_injective(outs): (io, ii) = s[x].split(list(s[x].op.axis)[-1], 8) s[x].vectorize(ii) tvm.schedule.AutoInlineInjective(s) - schedule_injective_from_existing(s, x) + + if not is_empty_shape(x.shape): + schedule_injective_from_existing(s, x) return s @generic.schedule_concatenate.register(["arm_cpu"]) diff --git a/topi/python/topi/cpp/__init__.py b/topi/python/topi/cpp/__init__.py index c0344170b035..9a11149417d3 100644 --- a/topi/python/topi/cpp/__init__.py +++ b/topi/python/topi/cpp/__init__.py @@ -24,3 +24,4 @@ from . import generic from . import rocm from . import image +from . import util diff --git a/topi/python/topi/cpp/util.py b/topi/python/topi/cpp/util.py new file mode 100644 index 000000000000..90264bc89170 --- /dev/null +++ b/topi/python/topi/cpp/util.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI for TOPI utility functions""" + +from tvm._ffi.function import _init_api_prefix + +_init_api_prefix("topi.cpp.util", "topi.util") diff --git a/topi/python/topi/cuda/injective.py b/topi/python/topi/cuda/injective.py index a6ec85377a13..0a131148be68 100644 --- a/topi/python/topi/cuda/injective.py +++ b/topi/python/topi/cuda/injective.py @@ -18,6 +18,7 @@ """Schedule for composition of injective operator""" import tvm from .. import generic, util +from ..util import is_empty_shape @generic.schedule_injective_from_existing.register(["cuda", "gpu"]) def schedule_injective_from_existing(sch, out): @@ -79,7 +80,8 @@ def schedule_injective(outs): tvm.schedule.AutoInlineInjective(s) for out in outs: - schedule_injective_from_existing(s, out) + if not is_empty_shape(out.shape): + schedule_injective_from_existing(s, out) return s schedule_elemwise = schedule_injective diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index e25e85dac05e..079dda5d0b0e 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -21,7 +21,7 @@ import tvm from tvm.api import layout, bijective_layout -from . import tag +from . import tag, cpp class InvalidShapeError(ValueError): """Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)""" @@ -417,3 +417,19 @@ def make_idx(b, e, s, z, i): (b - i) // tvm.abs(s), (i - b) // s) return tvm.if_then_else(tvm.expr.Or(bc, ec), 88, ss) + + +def is_empty_shape(shape): + """Check whether an input shape has dimesion with size 0. + + Parameter + --------- + shape : list of Expr + Input shape + + Returns + ------- + is_empty: bool + Whether input shape is empty or has dimesion with size 0. + """ + return cpp.util.is_empty_shape(shape) diff --git a/topi/python/topi/x86/injective.py b/topi/python/topi/x86/injective.py index 5bcb17921b79..8c97214ea4bb 100644 --- a/topi/python/topi/x86/injective.py +++ b/topi/python/topi/x86/injective.py @@ -19,6 +19,7 @@ from __future__ import absolute_import as _abs import tvm from .. import generic +from ..util import is_empty_shape @generic.schedule_injective_from_existing.register(["cpu"]) def schedule_injective_from_existing(sch, out): @@ -65,7 +66,9 @@ def schedule_injective(outs): x = outs[0] s = tvm.create_schedule([x.op for x in outs]) tvm.schedule.AutoInlineInjective(s) - schedule_injective_from_existing(s, x) + + if not is_empty_shape(x.shape): + schedule_injective_from_existing(s, x) return s @generic.schedule_concatenate.register(["cpu"]) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 11a90215d71f..779822ff8bbf 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -72,6 +72,8 @@ #include #include +#include + namespace topi { using namespace tvm; @@ -740,6 +742,12 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_l2_normalize") *rv = topi::cuda::schedule_l2_normalize(args[0], args[1]); }); +/* Utility functions */ +TVM_REGISTER_GLOBAL("topi.util.is_empty_shape") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::detail::is_empty_shape(args[0]); + }); + /*! \brief Builder function for instantiating schedules. */ using FTVMScheduleBuilder = std::function< tvm::Schedule(const tvm::Target& target, const tvm::Array& outs)>; diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 4dc485836ee6..e87c6db8856a 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -555,6 +555,7 @@ def test_strided_slice(): verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2]) verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3]) + verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3]) def test_strided_set(): verify_strided_set((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2]) @@ -596,6 +597,7 @@ def test_reshape(): verify_reshape((4, 2, 3, 4), (2, 4, 12)) verify_reshape((4, 2, 3, 4), (2, 48)) verify_reshape((16, ), (2, 2, 2, 2)) + verify_reshape((4, 0), (2, 0, 2)) def test_where(): @@ -718,6 +720,7 @@ def test_tile(): verify_tile((3, 2), (2, 3)) verify_tile((3, 2, 5), (2,)) verify_tile((3, ), (2, 3, 3)) + verify_tile((4, 0), (5,)) def test_layout_transform(): in_shape = (1, 32, 8, 8)