Skip to content

Commit

Permalink
[Topi]Allow empty tensor for reshape, tile and strided_slice (apache#…
Browse files Browse the repository at this point in the history
…4618)

* Support empty tensor

* Fix schedule

* Refactor

* Minor fix

* Fix pylint

* Merge cpp and python is_empty_shape
  • Loading branch information
kevinthesun authored and zhiics committed Mar 2, 2020
1 parent 6bcf775 commit 1d4850d
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 24 deletions.
4 changes: 4 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,7 @@ bool TakeRel(const Array<Type>& types,
CHECK(data != nullptr);
const auto* indices = types[1].as<TensorTypeNode>();
CHECK(indices != nullptr);
CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
const auto param = attrs.as<TakeAttrs>();
CHECK(param != nullptr);

Expand Down Expand Up @@ -1648,6 +1649,9 @@ bool SqueezeRel(const Array<Type>& types,
// if axes is None, squeeze all axes of dimension 1
if (!param->axis.defined()) {
for (const auto& e : data->shape) {
if (!e.as<IntImm>()) {
LOG(FATAL) << "axis needs to be defined for dynamic input.";
}
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
if (*axis_ptr != 1) {
Expand Down
55 changes: 55 additions & 0 deletions topi/include/topi/detail/tensor_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tensor_utils.h
* \brief Utility functions for handling tensor
*/
#ifndef TOPI_DETAIL_TENSOR_UTILS_H_
#define TOPI_DETAIL_TENSOR_UTILS_H_


namespace topi {
namespace detail {
using namespace tvm;

/*!
* \brief Check whether input shape has dimension of size 0;
*
* \param x Input shape
*
* \return True if the input shape is empty.
*/
inline bool is_empty_shape(const Array<Expr>& x) {
bool is_empty = false;
for (const auto& dim : x) {
if (auto int_dim = dim.as<IntImm>()) {
if (int_dim->value == 0) {
is_empty = true;
break;
}
}
}
return is_empty;
}

} // namespace detail
} // namespace topi
#endif // TOPI_DETAIL_TENSOR_UTILS_H_

59 changes: 39 additions & 20 deletions topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "topi/tags.h"
#include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h"
#include "topi/detail/tensor_utils.h"
#include "tvm/operation.h"
#include "tvm/expr_operator.h"
#include "tvm/data_layout.h"
Expand Down Expand Up @@ -207,16 +208,28 @@ inline Tensor reshape(const Tensor& x,
std::string name = "T_reshape",
std::string tag = kInjective) {
auto x_shape = x->shape;
Array<Expr> newshape_int32;
Array<Expr> target_shape;

for (const auto &ele : newshape) {
newshape_int32.push_back(cast(DataType::Int(32), ele));
if (ele.as<IntImm>()) {
target_shape.push_back(cast(DataType::Int(32), ele));
} else {
target_shape.push_back(ele);
}
}

if (is_empty_shape(target_shape)) {
return compute(target_shape,
[&](const Array<Var> &indices) { return tvm::cast(x->dtype, 0); },
name, tag);
} else {
return compute(
target_shape, [&](const Array<Var>& indices) {
return x(UnravelIndex(
RavelIndex(Array<Expr>{indices.begin(), indices.end()}, target_shape),
x_shape));
}, name, tag);
}
return compute(
newshape_int32, [&](const Array<Var>& indices) {
return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape_int32),
x_shape));
}, name, tag);
}

/*!
Expand Down Expand Up @@ -556,7 +569,7 @@ inline Tensor strided_slice(const Tensor& x,
int interval = std::abs(end_i - begin_i);
int slice_size = static_cast<int>((interval
+ std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i]));
CHECK(stride_vec[i] < 0 ? (end_i < begin_i) : (begin_i < end_i))
CHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
<< ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i]
<< "] is invalid for axis=" << i;

Expand Down Expand Up @@ -938,18 +951,24 @@ inline Tensor tile(const Tensor& x,
for (size_t i = 0; i < tdim; ++i)
new_shape.push_back(data_shape[i] * reps_shape[i]);

return compute(
new_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
if (ndim >= rdim) {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indexmod(indices[i], x->shape[i]));
} else {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
}
return x(idx);
}, name, tag);
if (is_empty_shape(new_shape)) {
return compute(new_shape,
[&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0);},
name, tag);
} else {
return compute(
new_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
if (ndim >= rdim) {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indexmod(indices[i], x->shape[i]));
} else {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
}
return x(idx);
}, name, tag);
}
}

/*!
Expand Down
5 changes: 4 additions & 1 deletion topi/python/topi/arm_cpu/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Schedule for pooling operators"""
import tvm
from .. import generic
from ..util import is_empty_shape

@generic.schedule_injective_from_existing.register(["arm_cpu"])
def schedule_injective_from_existing(sch, out):
Expand Down Expand Up @@ -68,7 +69,9 @@ def schedule_injective(outs):
(io, ii) = s[x].split(list(s[x].op.axis)[-1], 8)
s[x].vectorize(ii)
tvm.schedule.AutoInlineInjective(s)
schedule_injective_from_existing(s, x)

if not is_empty_shape(x.shape):
schedule_injective_from_existing(s, x)
return s

@generic.schedule_concatenate.register(["arm_cpu"])
Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/cpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
from . import generic
from . import rocm
from . import image
from . import util
21 changes: 21 additions & 0 deletions topi/python/topi/cpp/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI for TOPI utility functions"""

from tvm._ffi.function import _init_api_prefix

_init_api_prefix("topi.cpp.util", "topi.util")
4 changes: 3 additions & 1 deletion topi/python/topi/cuda/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Schedule for composition of injective operator"""
import tvm
from .. import generic, util
from ..util import is_empty_shape

@generic.schedule_injective_from_existing.register(["cuda", "gpu"])
def schedule_injective_from_existing(sch, out):
Expand Down Expand Up @@ -79,7 +80,8 @@ def schedule_injective(outs):

tvm.schedule.AutoInlineInjective(s)
for out in outs:
schedule_injective_from_existing(s, out)
if not is_empty_shape(out.shape):
schedule_injective_from_existing(s, out)
return s

schedule_elemwise = schedule_injective
Expand Down
18 changes: 17 additions & 1 deletion topi/python/topi/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import tvm
from tvm.api import layout, bijective_layout
from . import tag
from . import tag, cpp

class InvalidShapeError(ValueError):
"""Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)"""
Expand Down Expand Up @@ -417,3 +417,19 @@ def make_idx(b, e, s, z, i):
(b - i) // tvm.abs(s),
(i - b) // s)
return tvm.if_then_else(tvm.expr.Or(bc, ec), 88, ss)


def is_empty_shape(shape):
"""Check whether an input shape has dimesion with size 0.
Parameter
---------
shape : list of Expr
Input shape
Returns
-------
is_empty: bool
Whether input shape is empty or has dimesion with size 0.
"""
return cpp.util.is_empty_shape(shape)
5 changes: 4 additions & 1 deletion topi/python/topi/x86/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import absolute_import as _abs
import tvm
from .. import generic
from ..util import is_empty_shape

@generic.schedule_injective_from_existing.register(["cpu"])
def schedule_injective_from_existing(sch, out):
Expand Down Expand Up @@ -65,7 +66,9 @@ def schedule_injective(outs):
x = outs[0]
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
schedule_injective_from_existing(s, x)

if not is_empty_shape(x.shape):
schedule_injective_from_existing(s, x)
return s

@generic.schedule_concatenate.register(["cpu"])
Expand Down
8 changes: 8 additions & 0 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@
#include <topi/rocm/softmax.h>
#include <topi/rocm/normalization.h>

#include <topi/detail/tensor_utils.h>

namespace topi {

using namespace tvm;
Expand Down Expand Up @@ -740,6 +742,12 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_l2_normalize")
*rv = topi::cuda::schedule_l2_normalize(args[0], args[1]);
});

/* Utility functions */
TVM_REGISTER_GLOBAL("topi.util.is_empty_shape")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::detail::is_empty_shape(args[0]);
});

/*! \brief Builder function for instantiating schedules. */
using FTVMScheduleBuilder = std::function<
tvm::Schedule(const tvm::Target& target, const tvm::Array<tvm::Tensor>& outs)>;
Expand Down
3 changes: 3 additions & 0 deletions topi/tests/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ def test_strided_slice():
verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2])
verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])

def test_strided_set():
verify_strided_set((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2])
Expand Down Expand Up @@ -596,6 +597,7 @@ def test_reshape():
verify_reshape((4, 2, 3, 4), (2, 4, 12))
verify_reshape((4, 2, 3, 4), (2, 48))
verify_reshape((16, ), (2, 2, 2, 2))
verify_reshape((4, 0), (2, 0, 2))


def test_where():
Expand Down Expand Up @@ -718,6 +720,7 @@ def test_tile():
verify_tile((3, 2), (2, 3))
verify_tile((3, 2, 5), (2,))
verify_tile((3, ), (2, 3, 3))
verify_tile((4, 0), (5,))

def test_layout_transform():
in_shape = (1, 32, 8, 8)
Expand Down

0 comments on commit 1d4850d

Please sign in to comment.