From be80c301e8bd5aaecf7ad6963eebeddc232a81f2 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 15 Aug 2019 20:36:19 +0000 Subject: [PATCH] [Relay] Refactor - Move infer types to a header file. --- src/relay/op/tensor/transform.cc | 1100 +---------------------------- src/relay/op/tensor/transform.h | 1116 ++++++++++++++++++++++++++++++ 2 files changed, 1117 insertions(+), 1099 deletions(-) create mode 100644 src/relay/op/tensor/transform.h diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 03a92b35d3969..2851e7bc7f3a4 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -37,6 +37,7 @@ #include "../op_common.h" #include "../../../arithmetic/compute_expr.h" #include "../../pass/alter_op_layout.h" +#include "transform.h" namespace tvm { namespace relay { @@ -45,24 +46,6 @@ using ir::IntImm; // relay.cast TVM_REGISTER_NODE_TYPE(CastAttrs); -bool CastRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) { - CHECK(types[0].as()) - << "cast: expect input type to be TensorType but get " - << types[0]; - return false; - } - const auto* param = attrs.as(); - reporter->Assign(types[1], TensorTypeNode::make( - data->shape, param->dtype)); - return true; -} - Array CastCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -131,46 +114,6 @@ RELAY_REGISTER_OP("reinterpret") // relay.expand_dims TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); -bool ExpandDimsRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - // `types` contains: [data, result] - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) { - CHECK(types[0].as()) - << "expand_dims: expect input type to be TensorType but get " - << types[0]; - return false; - } - const auto* param = attrs.as(); - const int ndim = 
static_cast(data->shape.size()); - const int axis = param->axis; - const int num_newaxis = param->num_newaxis; - CHECK(num_newaxis >= 0) - << "expand_dims only accepts `num_newaxis >= 0`" - << ", but got num_newaxis = " << num_newaxis; - CHECK(-ndim - 1 <= axis && axis <= ndim) - << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; - const int pivot = axis < 0 ? ndim + axis + 1 : axis; - std::vector oshape; - oshape.reserve(ndim + num_newaxis); - for (int i = 0; i < pivot; ++i) { - oshape.emplace_back(data->shape[i]); - } - for (int i = 0; i < num_newaxis; ++i) { - oshape.emplace_back(1); - } - for (int i = pivot; i < ndim; ++i) { - oshape.emplace_back(data->shape[i]); - } - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); - return true; -} - Array ExpandDimsCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -210,86 +153,6 @@ RELAY_REGISTER_OP("expand_dims") // relay.concatenate TVM_REGISTER_NODE_TYPE(ConcatenateAttrs); -bool ConcatenateRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - // types: [data, result] - CHECK_EQ(types.size(), 2); - /* If we receive a tuple we can continue, if we receive - * anything but an incomplete type we should signal an - * error. - */ - const auto* tensor_tuple = types[0].as(); - if (tensor_tuple == nullptr) { - throw relay::Error( - RELAY_ERROR( - "concatenate requires a tuple of tensors as the first argument, found " - << PrettyPrint(types[0]))); - } else if (types[0].as() != nullptr) { - return false; - } - - const auto* param = attrs.as(); - if (tensor_tuple->fields[0].as()) { - return false; - } - const auto& first = Downcast(tensor_tuple->fields[0]); - // Sanity check: ndim and dtype. 
- const int ndim = static_cast(first->shape.size()); - const DataType dtype = first->dtype; - - for (const Type& ele : tensor_tuple->fields) { - if (ele.as()) { - return false; - } - - const auto& e = Downcast(ele); - - int e_ndim = static_cast(e->shape.size()); - const DataType& e_dtype = e->dtype; - if (e_ndim != ndim) { - throw relay::Error("relay.concatenate requires all tensors have the same ndim"); - } - if (e_dtype != dtype) { - throw relay::Error("relay.concatenate requires all tensors have the same dtype"); - } - } - // Sanity check: axis - int axis = param->axis; - if (!(-ndim <= axis && axis < ndim)) { - throw relay::Error(RELAY_ERROR( - "concatenate only accepts `axis` in [-ndim, ndim)" << - ", but got axis = " << axis << - ", and ndim = " << ndim)); - } - axis = axis < 0 ? ndim + axis : axis; - // Calculate shape - std::vector oshape(first->shape.begin(), first->shape.end()); - IndexExpr &concat_dim = oshape[axis]; - bool has_any = false; - if (concat_dim.as()) { - has_any = true; - } else { - for (int i = 1; i < static_cast(tensor_tuple->fields.size()); ++i) { - const auto& e = Downcast(tensor_tuple->fields[i]); - if (e->shape[axis].as()) { - has_any = true; - break; - } - concat_dim += e->shape[axis]; - } - } - - if (has_any) { - concat_dim = Any::make(); - } - - auto rtype = TensorTypeNode::make(oshape, dtype); - reporter->Assign(types[1], rtype); - return true; -} - Array ConcatenateCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -365,53 +228,6 @@ RELAY_REGISTER_OP("concatenate") TVM_REGISTER_NODE_TYPE(StackAttrs); -bool StackRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - // types: [data, result] - CHECK_EQ(types.size(), 2); - const auto* tensor_tuple = types[0].as(); - if (tensor_tuple == nullptr) { - CHECK(types[0].as()) - << "cast: expect input type to be TupleType but get " - << types[0]; - return false; - } - const auto* param = attrs.as(); - const auto& first = 
Downcast(tensor_tuple->fields[0]); - // Sanity check: ndim and dtype. - const int ndim = static_cast(first->shape.size()); - const DataType dtype = first->dtype; - for (const Type& ele : tensor_tuple->fields) { - const auto& e = Downcast(ele); - int e_ndim = static_cast(e->shape.size()); - const DataType& e_dtype = e->dtype; - CHECK_EQ(e_ndim, ndim) << "relay.stack requires all tensors have the same ndim"; - CHECK_EQ(e_dtype, dtype) << "relay.stack requires all tensors have the same dtype"; - } - // Sanity check: axis - int axis = param->axis; - CHECK(-ndim <= axis && axis < ndim) - << "stack only accepts `axis` in [-ndim, ndim)" - << ", but got axis = " << axis - << ", and ndim = " << ndim; - axis = axis < 0 ? ndim + axis + 1 : axis; - // Calculate shape - std::vector oshape; - oshape.reserve(ndim + 1); - const int stack_dim = static_cast(tensor_tuple->fields.size()); - for (int i = 0; i < axis; ++i) { - oshape.emplace_back(first->shape[i]); - } - oshape.emplace_back(stack_dim); - for (int i = axis; i < ndim; ++i) { - oshape.emplace_back(first->shape[i]); - } - reporter->Assign(types[1], TensorTypeNode::make(oshape, dtype)); - return true; -} - Array StackCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -451,59 +267,6 @@ RELAY_REGISTER_OP("stack") /* relay.transpose */ TVM_REGISTER_NODE_TYPE(TransposeAttrs); -bool TransposeRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - // types: [data, result] - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) { - CHECK(types[0].as()) - << "transpose: expect input type to be TensorType but get " - << types[0]; - return false; - } - const auto* param = attrs.as(); - const int ndim = data->shape.size(); - const Array& axes = param->axes; - // check dimension match - CHECK(!axes.defined() || static_cast(axes.size()) == ndim) - << "Dimension mismatch: axes has " << axes.size() << " elements" - << ", but data.ndim = " << 
ndim; - // construct int_axes - std::vector int_axes; - int_axes.reserve(ndim); - // used not defined to check if it is None. - if (!axes.defined()) { - for (int i = ndim - 1; i >= 0; --i) { - int_axes.push_back(i); - } - } else { - std::vector axis_used(ndim, 0); - for (const Integer& e : axes) { - int64_t axis = e; - // sanity check for axis and ndim - CHECK(-ndim <= axis && axis < ndim) - << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; - axis = axis < 0 ? axis + ndim : axis; - // sanity check for duplication - CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis; - axis_used[axis] = 1; - int_axes.push_back(static_cast(axis)); - } - } - std::vector oshape; - oshape.reserve(ndim); - for (int axis : int_axes) { - oshape.push_back(data->shape[axis]); - } - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); - return true; -} - Array TransposeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -543,144 +306,6 @@ RELAY_REGISTER_OP("transpose") /* relay.reshape */ TVM_REGISTER_NODE_TYPE(ReshapeAttrs); -bool ReshapeRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - // types: [data, result] - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) { - CHECK(types[0].as()) - << "reshape: expect input type to be TensorType but get " - << types[0]; - return false; - } - - const auto* param = attrs.as(); - Array data_shape; - Array newshape; - if (param->reverse) { - data_shape.assign(data->shape.rbegin(), data->shape.rend()); - newshape.assign(param->newshape.rbegin(), param->newshape.rend()); - } else { - data_shape = data->shape; - newshape = param->newshape; - } - Array oshape; - std::unordered_set used_input_dims; - std::unordered_set used_output_dims; - size_t src_idx = 0; - int infer_idx = -1; - - for (size_t i = 0; i < newshape.size(); 
++i) { - int svalue = newshape[i]->value; - // special flag handling for shape inference. - if (svalue > 0) { - oshape.push_back(newshape[i]); - ++src_idx; - } else if (svalue == 0) { - // keep same - CHECK_LT(src_idx, data_shape.size()); - used_input_dims.insert(src_idx); - used_output_dims.insert(oshape.size()); - oshape.push_back(data_shape[src_idx++]); - } else if (svalue == -1) { - // inference based on rest - CHECK_LT(infer_idx, 0) - << "One and only one dim can be inferred"; - infer_idx = i; - oshape.push_back(1); - ++src_idx; - } else if (svalue == -2) { - // copy all remaining dims from source - while (src_idx < data_shape.size()) { - used_input_dims.insert(src_idx); - used_output_dims.insert(oshape.size()); - oshape.push_back(data_shape[src_idx++]); - } - } else if (svalue == -3) { - // merge two dims from source - CHECK_LT(src_idx + 1, data_shape.size()); - used_input_dims.insert(src_idx); - IndexExpr d1 = data_shape[src_idx++]; - used_input_dims.insert(src_idx); - IndexExpr d2 = data_shape[src_idx++]; - used_output_dims.insert(oshape.size()); - oshape.push_back(d1 * d2); - } else if (svalue == -4) { - // split the source dim s into two dims - // read the left dim and then the right dim (either can be -1) - CHECK_LT(i + 2, newshape.size()); - CHECK_LT(src_idx, data_shape.size()); - used_input_dims.insert(src_idx); - IndexExpr d0 = data_shape[src_idx++]; - Integer d1 = newshape[++i]; - Integer d2 = newshape[++i]; - if (d1->value == -1) { - CHECK(d2->value != -1) - << "Split dims cannot both be -1."; - used_output_dims.insert(oshape.size()); - if (d0.as()) { - oshape.push_back(Any::make()); - } else { - oshape.push_back(d0 / d2); - } - used_output_dims.insert(oshape.size()); - oshape.push_back(d2); - } else { - used_output_dims.insert(oshape.size()); - oshape.push_back(d1); - used_output_dims.insert(oshape.size()); - if (d2->value == -1) { - if (d0.as()) { - oshape.push_back(Any::make()); - } else { - oshape.push_back(d0 / d1); - } - } else { - 
oshape.push_back(d2); - } - } - } - } - - if (infer_idx >= 0) { - IndexExpr infer_dim = 1; - for (size_t i = 0; i < data_shape.size(); ++i) { - if (used_input_dims.count(i) != 0) { - continue; - } - if (data_shape[i].as()) { - infer_dim = Any::make(); - break; - } - infer_dim *= data_shape[i]; - } - if (!infer_dim.as()) { - for (size_t i = 0; i < oshape.size(); ++i) { - if (used_output_dims.count(i) != 0) { - continue; - } - if (oshape[i].as()) { - infer_dim = Any::make(); - break; - } - infer_dim /= oshape[i]; - } - } - oshape.Set(infer_idx, infer_dim); - } - - if (param->reverse) { - reporter->Assign(types[1], TensorTypeNode::make( - Array(oshape.rbegin(), oshape.rend()), data->dtype)); - } else { - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); - } - return true; -} - Array ReshapeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -761,46 +386,15 @@ Example:: .set_attr("FTVMCompute", ReshapeCompute) .set_attr("TOpPattern", kInjective); - -/*! -* \brief ReshapeLikeRel User defined type constraint function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return False if the relation has not been resolved, it might be resolved later. -* True if this relation has been resolved. 
-*/ -bool ReshapeLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - if (data == nullptr) { - return false; - } - const auto* reshape_like = types[1].as(); - if (reshape_like == nullptr) { - return false; - } - CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size())) - << "Reshape inputs size should be compatible."; - reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype)); - return true; -} - - Expr MakeReshapeLike(Expr data, Expr shape_like) { static const Op& op = Op::Get("reshape_like"); return CallNode::make(op, {data, shape_like}, Attrs(), {}); } - TVM_REGISTER_API("relay.op._make.reshape_like") .set_body_typed(MakeReshapeLike); - RELAY_REGISTER_OP("reshape_like") .describe(R"code(Reshapes the input array by the size of another array. For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes @@ -816,53 +410,9 @@ the input array into an output array with the same shape as the second input arr .set_attr("FTVMCompute", ReshapeCompute) .set_attr("TOpPattern", kInjective); - // Take TVM_REGISTER_NODE_TYPE(TakeAttrs); -bool TakeRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - // `types` contains: [data, indices, result] - CHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - CHECK(data != nullptr); - const auto* indices = types[1].as(); - CHECK(indices != nullptr); - const auto param = attrs.as(); - CHECK(param != nullptr); - - if (!param->axis.defined()) { - std::vector oshape(indices->shape.begin(), indices->shape.end()); - reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); - return true; - } - - std::vector oshape; - const auto ndim_data = static_cast(data->shape.size()); - const auto ndim_indices = static_cast(indices->shape.size()); - int axis = static_cast(param->axis->value); - if (axis < 0) axis += 
ndim_data; - CHECK_LE(axis, ndim_data) - << "axis should be with in data shape" - << ", but got = " << axis; - - oshape.reserve(ndim_data - 1 + ndim_indices); - for (int i = 0; i < axis; ++i) { - oshape.emplace_back(data->shape[i]); - } - for (int i = 0; i < ndim_indices; ++i) { - oshape.emplace_back(indices->shape[i]); - } - for (int i = axis+1; i < ndim_data; ++i) { - oshape.emplace_back(data->shape[i]); - } - - reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); - return true; -} - Array TakeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -922,34 +472,9 @@ Examples:: .set_attr("FTVMCompute", TakeCompute) .set_attr("TOpPattern", kInjective); - // Init ops TVM_REGISTER_NODE_TYPE(InitOpAttrs); -bool FullRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - const InitOpAttrs* param = attrs.as(); - const auto* fill_value = types[0].as(); - if (fill_value == nullptr) { - return false; - } - - DataType out_dtype = param->dtype; - if (out_dtype.bits() == 0) { - out_dtype = fill_value->dtype; - } - - CHECK_EQ(fill_value->shape.size(), 0) - << "Fill value should be a scalar but has dimension " - << fill_value->shape.size() << "."; - - reporter->Assign(types[1], TensorTypeNode::make(param->shape, out_dtype)); - return true; -} - Array FullCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -983,17 +508,6 @@ RELAY_REGISTER_OP("full") .set_attr("FTVMCompute", FullCompute) .set_attr("TOpPattern", kElemWise); -bool InitOpRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 1); - const InitOpAttrs* param = attrs.as(); - - reporter->Assign(types[0], TensorTypeNode::make(param->shape, param->dtype)); - return true; -} - Expr MakeZeros(Array shape, DataType dtype) { auto attrs = make_node(); @@ -1036,28 +550,6 @@ RELAY_REGISTER_OP("ones") .set_support_level(3) 
.add_type_rel("InitOp", InitOpRel); -bool FullLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - if (data == nullptr) { - return false; - } - const auto* fill_value = types[1].as(); - if (fill_value == nullptr) { - return false; - } - - CHECK_EQ(fill_value->shape.size(), 0) - << "The fill value should be a scalar but here it has dimension " - << fill_value->shape.size() << "."; - - reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype)); - return true; -} - Array FullLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -1090,44 +582,6 @@ and type as the input array. // arange operator TVM_REGISTER_NODE_TYPE(ArangeAttrs); -double ToScalar(const runtime::NDArray& array) { - if (array->dtype.code == kDLInt || array->dtype.code == kDLUInt) { - return reinterpret_cast(array->data)[0]; - } else { - return reinterpret_cast(array->data)[0]; - } -} - -bool ArangeRel(const Array& types, - int num_inputs, - const Attrs& raw_attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 4); - const ArangeAttrs* attrs = raw_attrs.as(); - const ConstantNode *cstart, *cstop, *cstep; - - reporter->Assign(types[0], types[1]); - reporter->Assign(types[1], types[2]); - reporter->Assign(types[2], TensorTypeNode::make({}, attrs->dtype)); - - if ((cstart = attrs->start.as()) && - (cstop = attrs->stop.as()) && - (cstep = attrs->step.as())) { - double start = ToScalar(cstart->data); - double stop = ToScalar(cstop->data); - double step = ToScalar(cstep->data); - int32_t num_elem = static_cast(std::ceil((stop - start) / step)); - CHECK_GT(num_elem, 0) - << "Invalid arange attributes (start, stop, step): " << attrs->start - << ", " << attrs->stop << ", " << attrs->step; - reporter->Assign(types[3], TensorTypeNode::make({num_elem}, attrs->dtype)); - return true; - } else { - reporter->Assign(types[3], 
TensorTypeNode::make({Any::make()}, attrs->dtype)); - return true; - } -} - inline Tensor DynamicArange(const tvm::Tensor& start, const tvm::Tensor& stop, const tvm::Tensor& step, tvm::Type dtype, std::string name = "tensor", std::string tag = topi::kInjective) { @@ -1193,44 +647,6 @@ RELAY_REGISTER_OP("arange") // repeat operator TVM_REGISTER_NODE_TYPE(RepeatAttrs); -bool RepeatRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - // `types` contains: [data, result] - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) { - CHECK(types[0].as()) - << "repeat: expect input type to be TensorType but get " - << types[0]; - return false; - } - const auto* param = attrs.as(); - const int ndim = static_cast(data->shape.size()); - const int repeats = param->repeats; - const int axis = param->axis; - CHECK(repeats >= 1) - << "repeat only accepts `repeats >= 1`" - << ", but got repeats = " << repeats; - CHECK(-ndim - 1 <= axis && axis <= ndim) - << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; - const int pivot = axis < 0 ? 
ndim + axis : axis; - std::vector oshape; - oshape.reserve(ndim + repeats); - for (int i = 0; i < pivot; ++i) { - oshape.emplace_back(data->shape[i]); - } - oshape.emplace_back(data->shape[pivot] * repeats); - for (int i = pivot + 1; i < ndim; ++i) { - oshape.emplace_back(data->shape[i]); - } - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); - return true; -} - Array RepeatCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -1270,67 +686,6 @@ RELAY_REGISTER_OP("repeat") // tile operator TVM_REGISTER_NODE_TYPE(TileAttrs); -bool TileRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - // `types` contains: [data, result] - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) { - CHECK(types[0].as()) - << "tile: expect input type to be TensorType but get " - << types[0]; - return false; - } - const auto* param = attrs.as(); - const size_t ndim = data->shape.size(); - const Array& reps = param->reps; - // check dimension match - CHECK(reps.defined()) - << "repetition array is not defined. data.ndim = " << ndim; - const size_t rndim = reps.size(); - for (size_t i = 0; i < rndim; ++i) { - if (const tvm::ir::IntImm* val = reps[i].as()) { - CHECK_GT(val->value, 0) - << "Tile reps value should always be larger than 0, but get: " << val->value; - } - } - size_t tndim = (ndim > rndim) ? 
ndim : rndim; - // re-construct data shape or reps shape - std::vector data_shape; - std::vector reps_shape; - data_shape.reserve(tndim); - reps_shape.reserve(tndim); - if (ndim == rndim) { - for (size_t i = 0; i < tndim; ++i) { - data_shape.emplace_back(data->shape[i]); - reps_shape.emplace_back(reps[i]); - } - } else if (ndim > rndim) { - for (size_t i = 0; i < ndim; ++i) - data_shape.emplace_back(data->shape[i]); - for (size_t i = 0; i < (ndim - rndim); ++i) - reps_shape.emplace_back(1); - for (size_t i = 0; i < rndim; ++i) - reps_shape.emplace_back(reps[i]); - } else { - for (size_t i = 0; i < rndim; ++i) - reps_shape.emplace_back(reps[i]); - for (size_t i = 0; i < (rndim - ndim); ++i) - data_shape.emplace_back(1); - for (size_t i = 0; i < ndim; ++i) - data_shape.emplace_back(data->shape[i]); - } - std::vector oshape; - oshape.reserve(tndim); - for (size_t i = 0; i < tndim; ++i) { - oshape.emplace_back(data_shape[i] * reps_shape[i]); - } - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); - return true; -} - Array TileCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -1368,30 +723,6 @@ RELAY_REGISTER_OP("tile") // reverse operator TVM_REGISTER_NODE_TYPE(ReverseAttrs); -bool ReverseRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - // `types` contains: [data, result] - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) { - CHECK(types[0].as()) - << "reverse: expect input type to be TensorType but get " - << types[0]; - return false; - } - const auto* param = attrs.as(); - const int ndim = static_cast(data->shape.size()); - const int axis = param->axis; - CHECK(-ndim <= axis && axis < ndim) - << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; - reporter->Assign(types[1], types[0]); - return true; -} - Array ReverseCompute(const Attrs& attrs, const Array& 
inputs, const Type& out_type, @@ -1427,39 +758,6 @@ RELAY_REGISTER_OP("reverse") .set_attr("TOpPattern", kInjective); // where operator -bool WhereRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 4U); - const auto* condition = types[0].as(); - const auto* x = types[1].as(); - const auto* y = types[2].as(); - CHECK(condition != nullptr && x != nullptr && y != nullptr); - - const auto& cond_shape = condition->shape; - const auto& x_shape = x->shape; - const auto& y_shape = y->shape; - CHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size"; - - if (cond_shape.size() != x_shape.size()) { - CHECK_EQ(cond_shape.size(), 1) - << "Shape of condition " << condition->shape - << " must be either equal to x or has dimension of 1."; - } - for (size_t i = 0; i < x_shape.size(); i++) { - CHECK(reporter->AssertEQ(x_shape[i], y_shape[i])) - << "x and y must have the same shape: " << x_shape << " vs " << y_shape; - - if (i < cond_shape.size()) { - CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i])) - << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape; - } - } - reporter->Assign(types[3], TensorTypeNode::make(x_shape, x->dtype)); - return true; -} - // Positional relay function to create where operator. 
Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) { static const Op& op = Op::Get("where"); @@ -1500,7 +798,6 @@ Examples:: cond = [[0, 1], [-1, 0]] where(cond, x, y) = [[5, 2], [3, 8]] - cond = [1, 0] where(cond, x, y) = [[1, 2], [7, 8]] @@ -1514,7 +811,6 @@ Examples:: .set_attr("FTVMCompute", WhereCompute) .set_attr("TOpPattern", kBroadcast); - // Squeeze TVM_REGISTER_NODE_TYPE(SqueezeAttrs); @@ -1529,57 +825,6 @@ Expr MakeSqueeze(Expr data, TVM_REGISTER_API("relay.op._make.squeeze") .set_body_typed(MakeSqueeze); - -bool SqueezeRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) { - return false; - } - const auto* param = attrs.as(); - CHECK(param != nullptr); - std::vector result_shape; - // if axes is None, squeeze all axes of dimension 1 - if (!param->axis.defined()) { - for (const auto& e : data->shape) { - const int64_t* axis_ptr = as_const_int(e); - CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete"; - if (*axis_ptr != 1) { - result_shape.push_back(e); - } - } - } else { - // pair up original shape with a boolean which control whether it will be in the final shape. 
- std::vector > original_shape; - for (const auto& e : data->shape) { - original_shape.push_back(std::pair(e, true)); - } - for (const auto& e : param->axis) { - int64_t axis_val = e->value; - if (axis_val < 0) { - axis_val += static_cast(original_shape.size()); - } - CHECK_GE(axis_val, 0); - CHECK_LT(axis_val, original_shape.size()); - original_shape.at(axis_val).second = false; - } - for (const auto p : original_shape) { - if (p.second) { - result_shape.push_back(p.first); - } else { - const int64_t* axis_ptr = as_const_int(p.first); - CHECK(axis_ptr != nullptr) << "cannot get concrete shape of input tensor"; - CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1"; - } - } - } - reporter->Assign(types[1], TensorTypeNode::make(result_shape, data->dtype)); - return true; -} - Array SqueezeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -1589,7 +834,6 @@ Array SqueezeCompute(const Attrs& attrs, return { topi::squeeze(inputs[0], param->axis) }; } - RELAY_REGISTER_OP("squeeze") .describe(R"code(Squeeze the input tensor at the dimensions given by axes @@ -1604,18 +848,6 @@ RELAY_REGISTER_OP("squeeze") .set_attr("FTVMCompute", SqueezeCompute) .set_attr("TOpPattern", kInjective); - -// Have no idea how to assert the constraint. 
-// CollapseSumLike: -> B where BroadCast(A, B) = A -bool CollapseSumLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); - reporter->Assign(types[2], types[1]); - return true; -} - Expr MakeCollapseSumLike(Expr data, Expr collapse_type) { static const Op& op = Op::Get("collapse_sum_like"); @@ -1645,21 +877,6 @@ RELAY_REGISTER_OP("collapse_sum_like") .set_attr("FTVMCompute", CollapseSumLikeCompute) .set_attr("TOpPattern", kCommReduce); -// BroadCastTo: -> B where BroadCast(A, B) = B -bool BroadCastToRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - auto ioattrs = attrs.as(); - CHECK(ioattrs); - auto intt = types[0].as(); - if (intt == nullptr) { return false; } - auto type = TensorTypeNode::make(ioattrs->shape, intt->dtype); - reporter->Assign(types[1], type); - return true; -} - Expr MakeBroadCastTo(Expr data, Array shape) { static const Op& op = Op::Get("broadcast_to"); auto attrs = make_node(); @@ -1689,16 +906,6 @@ RELAY_REGISTER_OP("broadcast_to") .set_attr("FTVMCompute", BroadCastToCompute) .set_attr("TOpPattern", kBroadcast); -// BroadCastToLike: -> B where BroadCast(A, B) = B -bool BroadCastToLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); - reporter->Assign(types[2], types[1]); - return true; -} - Expr MakeBroadCastToLike(Expr data, Expr broadcast_type) { static const Op& op = Op::Get("broadcast_to_like"); @@ -1728,7 +935,6 @@ RELAY_REGISTER_OP("broadcast_to_like") .set_attr("FTVMCompute", BroadCastToLikeCompute) .set_attr("TOpPattern", kBroadcast); - // Adapter function to make int array. 
Array GetIntArray(Array arr) { for (size_t i = 0; i < arr.size(); ++i) { @@ -1738,109 +944,8 @@ Array GetIntArray(Array arr) { return Array(arr.node_); } - // strided_slice TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); -bool StridedSliceRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) return false; - - const StridedSliceAttrs *param = attrs.as(); - CHECK(param != nullptr); - - auto dshape = data->shape; - auto num_axis = dshape.size(); - - std::vector stride_vec; - for (Integer i : param->strides) { - CHECK(i.defined()); - stride_vec.push_back(i->value); - } - for (size_t i = stride_vec.size(); i < num_axis; ++i) { - stride_vec.push_back(1); - } - const int64_t max_range = std::numeric_limits::max(); - - std::vector begin_vec; - for (size_t i = 0; i < param->begin.size(); ++i) { - if (!param->begin[i].defined()) { - // value=None - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } else { - begin_vec.push_back(param->begin[i]->value); - } - } - for (size_t i = begin_vec.size(); i < num_axis; ++i) { - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } - - std::vector end_vec; - for (size_t i = 0; i < param->end.size(); ++i) { - // allow end to be None - if (!param->end[i].defined()) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } else { - end_vec.push_back(param->end[i]->value); - } - } - for (size_t i = end_vec.size(); i < num_axis; ++i) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } - - std::vector oshape(dshape.size()); - for (size_t i = 0; i < num_axis; ++i) { - int64_t stride_v = stride_vec[i]; - int64_t begin_v = begin_vec[i]; - int64_t end_v = end_vec[i]; - - if ((stride_v == 1 && - begin_v == 0 && - end_v == max_range) || - (stride_v == -1 && - begin_v == max_range && - end_v == 0)) { - // Quick path, do not slice this dimension. 
- oshape[i] = dshape[i]; - continue; - } - // Normal path, require the shape to be concrete integer. - // Require concrete integer as symbolic inference of min/max - // can get complicated and not very helpful. - const int64_t* p_dim_size = as_const_int(dshape[i]); - CHECK(p_dim_size) - << "strided_slice requires sliced dimension to be concrete int"; - int64_t dim_size = p_dim_size[0]; - begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; - end_v = (end_v < 0) ? dim_size + end_v : end_v; - - int64_t slice_range, step; - if (stride_v < 0) { - if (end_v < -1) end_v = -1; - CHECK_LT(end_v, begin_v) - << "strided_slice get empty slice at axis " << i; - begin_v = std::min(dim_size - 1, begin_v); - slice_range = begin_v - end_v; - step = -stride_v; - } else { - if (begin_v < 0) begin_v = 0; - CHECK_GE(stride_v, 0); - CHECK_LT(begin_v, end_v) - << "strided_slice get empty slice at axis " << i; - end_v = std::min(dim_size, end_v); - slice_range = end_v - begin_v; - step = stride_v; - } - oshape[i] = make_const(dshape[i].type(), (slice_range + step - 1) / step); - } - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); - return true; -} - - Array > StridedSliceInferCorrectLayout( const Attrs& attrs, const Array& new_in_layouts, @@ -1898,7 +1003,6 @@ Array > StridedSliceInferCorrectLayout( return {{layout}, {layout}}; } - // Positional relay function to create StridedSlice operator used by frontend FFI. Expr MakeStridedSlice(Expr data, Array begin, @@ -1923,11 +1027,9 @@ Array StridedSliceCompute(const Attrs& attrs, }; } - TVM_REGISTER_API("relay.op._make.strided_slice") .set_body_typed(MakeStridedSlice); - RELAY_REGISTER_OP("strided_slice") .describe(R"code(Strided slice of an array. 
@@ -1961,66 +1063,9 @@ Examples:: .set_attr("TOpPattern", kInjective) .set_attr("FInferCorrectLayout", StridedSliceInferCorrectLayout); - // relay.split TVM_REGISTER_NODE_TYPE(SplitAttrs); -bool SplitRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - // `types` contains: [data, result] - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) return false; - CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty"; - const auto param = attrs.as(); - CHECK(param != nullptr); - auto axis = param->axis; - if (axis < 0) { - axis += data->shape.size(); - } - CHECK_LT(axis, data->shape.size()) - << "axis should be within the input dimension range."; - CHECK_GE(axis, 0) - << "axis should be within the input dimension range."; - - if (const IntImm* sections = param->indices_or_sections.as()) { - CHECK(reporter->Assert(data->shape[axis] % - sections->value == make_zero(Int(64)))) - << "indices_or_sections need to be able to divide input.shape[axis]"; - std::vector fields; - for (int i = 0; i < sections->value; ++i) { - std::vector oshape(data->shape.begin(), data->shape.end()); - oshape[axis] /= int32_t(sections->value); - auto vec_type = TensorTypeNode::make(oshape, data->dtype); - fields.push_back(vec_type); - } - reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); - } else { - auto indices = param->indices_or_sections.as()->data; - auto begin = IndexExpr(make_zero(Int(32))); - std::vector fields; - for (unsigned int i = 0; i < indices.size(); ++i) { - CHECK(reporter->Assert(IndexExpr(indices[i]) > begin)) - << "indices_or_sections need to be a sorted ascending list"; - std::vector oshape(data->shape.begin(), data->shape.end()); - oshape[axis] = IndexExpr(indices[i]) - begin; - begin = IndexExpr(indices[i]); - auto vec_type = TensorTypeNode::make(oshape, data->dtype); - fields.push_back(vec_type); - } - CHECK(reporter->Assert(begin < data->shape[axis])) - << "The sum of 
sections must match the input.shape[axis]"; - std::vector oshape(data->shape.begin(), data->shape.end()); - oshape[axis] = data->shape[axis] - begin; - auto vec_type = TensorTypeNode::make(oshape, data->dtype); - fields.push_back(vec_type); - reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); - } - return true; -} - Array SplitCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -2076,71 +1121,9 @@ the entries indicate where along axis the array is split. .set_attr("FTVMCompute", SplitCompute) .set_attr("TOpPattern", kInjective); - // relay.slice_like TVM_REGISTER_NODE_TYPE(SliceLikeAttrs); -/*! -* \brief SliceLikeRel User defined type constraint function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return False if the relation has not been resolved, it might be resolved later. -* True if this relation has been resolved. -*/ -bool SliceLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - if (data == nullptr) { - return false; - } - - const auto* target = types[1].as(); - if (target == nullptr) { - return false; - } - - const auto param = attrs.as(); - CHECK(param != nullptr); - - const Array& dshape = data->shape; - const Array& target_shape = target->shape; - std::vector oshape(dshape.begin(), dshape.end()); - - if (!param->axes.defined()) { - for (size_t i = 0; i < dshape.size(); ++i) { - if (i < target_shape.size()) { - oshape[i] = target_shape[i]; - CHECK(reporter->Assert(oshape[i] <= dshape[i])) - << "End index of axis " << i << " exceeds input shape: " - << oshape[i] << " vs " << dshape[i]; - } - } - } else { - CHECK(param->axes.size() != 0) << "Axes cannot be empty."; - for (Integer val : param->axes) { - int axis = val->value; - if (axis < 0) { - axis += dshape.size(); - } - 
CHECK(axis < static_cast(target_shape.size())) - << "Axis " << axis << " exceeds dimension " - << target_shape.size() << " of target_shape."; - oshape[axis] = target_shape[axis]; - CHECK(reporter->Assert(oshape[axis] <= dshape[axis])) - << "End index of axis " << axis << " exceeds input shape: " - << oshape[axis] << " vs " << dshape[axis]; - } - } - - reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); - return true; -} - - Expr MakeSliceLike(Expr data, Expr shape_like, Array axes) { @@ -2196,11 +1179,9 @@ Array SliceLikeCompute(const Attrs& attrs, }; } - TVM_REGISTER_API("relay.op._make.slice_like") .set_body_typed(MakeSliceLike); - RELAY_REGISTER_OP("slice_like") .describe(R"code(Slice the first input respect to the second input. )code" TVM_ADD_FILELINE) @@ -2225,29 +1206,6 @@ Array LayoutTransformCompute(const Attrs& attrs, }; } -bool LayoutTransformRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - const auto* data = types[0].as(); - CHECK(data != nullptr); - const LayoutTransformAttrs* params = attrs.as(); - - Layout src_layout(params->src_layout); - Layout dst_layout(params->dst_layout); - - CHECK(src_layout.defined() && dst_layout.defined()) - << "cannot convert from/to undefined layout"; - - auto layout_converter = BijectiveLayoutNode::make(src_layout, dst_layout); - CHECK(layout_converter.defined()) - << "cannot convert from " << params->src_layout << " to " << params->dst_layout; - - const auto& out_shape = layout_converter.ForwardShape(data->shape); - reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype)); - return true; -} - Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout) { @@ -2275,7 +1233,6 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] .set_support_level(5) .set_attr("FTVMCompute", LayoutTransformCompute); - /* relay._contrib_reverse_reshape */ Expr MakeReverseReshape(Expr data, Array newshape) { @@ -2311,42 
+1268,6 @@ example below:: .set_attr("FTVMCompute", ReshapeCompute) .set_attr("TOpPattern", kInjective); -// gather_nd operator -bool GatherNDRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - // `types` contains: [data, indices, result] - CHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* indices = types[1].as(); - if (data == nullptr) { - CHECK(types[0].as()) - << "GatherND: expect input data type to be TensorType but get " - << types[0]; - return false; - } - if (indices == nullptr) { - CHECK(types[1].as()) - << "GatherND: expect indices type to be TensorType but get " - << types[1]; - return false; - } - const size_t ndim = data->shape.size(); - const IntImm* mdim = indices->shape[0].as(); - const size_t kdim = indices->shape.size() - 1; - CHECK(size_t(mdim->value) <= ndim) - << "GatherND: indices shape does satisfy."; - - Array oshape; - for (size_t i = 1; i < kdim + 1; ++i) - oshape.push_back(indices->shape[i]); - for (size_t i = mdim->value; i < ndim; ++i) - oshape.push_back(data->shape[i]); - reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); - return true; -} - Array GatherNDCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -2382,25 +1303,6 @@ output shape will simply be (Y_0, ..., Y_{K-1}). 
// relay.sequence_mask TVM_REGISTER_NODE_TYPE(SequenceMaskAttrs); -bool SequenceMaskRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - // `types` contains: [data, valid_length, result] - CHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* valid_length = types[1].as(); - CHECK(data); - CHECK(valid_length); - const auto param = attrs.as(); - Array valid_length_shape; - CHECK(param->axis == 0 || param->axis == 1); - valid_length_shape.push_back(data->shape[1 - param->axis]); - reporter->Assign(types[1], TensorTypeNode::make(valid_length_shape, valid_length->dtype)); - reporter->Assign(types[2], types[0]); - return true; -} - Array SequenceMaskCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h new file mode 100644 index 0000000000000..d1eadd4e7ba46 --- /dev/null +++ b/src/relay/op/tensor/transform.h @@ -0,0 +1,1116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/op/tensor/transform.h + * \brief Tranform op attributes that can be shared among Relay and its dialects. 
+ */ +#ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_ +#define TVM_RELAY_OP_TENSOR_TRANSFORM_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +bool CastRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "cast: expect input type to be TensorType but get " + << types[0]; + return false; + } + const auto* param = attrs.as(); + reporter->Assign(types[1], TensorTypeNode::make( + data->shape, param->dtype)); + return true; +} + +bool ExpandDimsRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, result] + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "expand_dims: expect input type to be TensorType but get " + << types[0]; + return false; + } + const auto* param = attrs.as(); + const int ndim = static_cast(data->shape.size()); + const int axis = param->axis; + const int num_newaxis = param->num_newaxis; + CHECK(num_newaxis >= 0) + << "expand_dims only accepts `num_newaxis >= 0`" + << ", but got num_newaxis = " << num_newaxis; + CHECK(-ndim - 1 <= axis && axis <= ndim) + << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis + << ", and data.ndim = " << ndim; + const int pivot = axis < 0 ? 
ndim + axis + 1 : axis; + std::vector oshape; + oshape.reserve(ndim + num_newaxis); + for (int i = 0; i < pivot; ++i) { + oshape.emplace_back(data->shape[i]); + } + for (int i = 0; i < num_newaxis; ++i) { + oshape.emplace_back(1); + } + for (int i = pivot; i < ndim; ++i) { + oshape.emplace_back(data->shape[i]); + } + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + +bool ConcatenateRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // types: [data, result] + CHECK_EQ(types.size(), 2); + /* If we receive a tuple we can continue, if we receive + * anything but an incomplete type we should signal an + * error. + */ + const auto* tensor_tuple = types[0].as(); + if (tensor_tuple == nullptr) { + throw relay::Error( + RELAY_ERROR( + "concatenate requires a tuple of tensors as the first argument, found " + << PrettyPrint(types[0]))); + } else if (types[0].as() != nullptr) { + return false; + } + + const auto* param = attrs.as(); + if (tensor_tuple->fields[0].as()) { + return false; + } + const auto& first = Downcast(tensor_tuple->fields[0]); + // Sanity check: ndim and dtype. + const int ndim = static_cast(first->shape.size()); + const DataType dtype = first->dtype; + + for (const Type& ele : tensor_tuple->fields) { + if (ele.as()) { + return false; + } + + const auto& e = Downcast(ele); + + int e_ndim = static_cast(e->shape.size()); + const DataType& e_dtype = e->dtype; + if (e_ndim != ndim) { + throw relay::Error("relay.concatenate requires all tensors have the same ndim"); + } + if (e_dtype != dtype) { + throw relay::Error("relay.concatenate requires all tensors have the same dtype"); + } + } + // Sanity check: axis + int axis = param->axis; + if (!(-ndim <= axis && axis < ndim)) { + throw relay::Error(RELAY_ERROR( + "concatenate only accepts `axis` in [-ndim, ndim)" << + ", but got axis = " << axis << + ", and ndim = " << ndim)); + } + axis = axis < 0 ? 
ndim + axis : axis; + // Calculate shape + std::vector oshape(first->shape.begin(), first->shape.end()); + IndexExpr &concat_dim = oshape[axis]; + bool has_any = false; + if (concat_dim.as()) { + has_any = true; + } else { + for (int i = 1; i < static_cast(tensor_tuple->fields.size()); ++i) { + const auto& e = Downcast(tensor_tuple->fields[i]); + if (e->shape[axis].as()) { + has_any = true; + break; + } + concat_dim += e->shape[axis]; + } + } + + if (has_any) { + concat_dim = Any::make(); + } + + auto rtype = TensorTypeNode::make(oshape, dtype); + reporter->Assign(types[1], rtype); + return true; +} + +bool StackRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // types: [data, result] + CHECK_EQ(types.size(), 2); + const auto* tensor_tuple = types[0].as(); + if (tensor_tuple == nullptr) { + CHECK(types[0].as()) + << "cast: expect input type to be TupleType but get " + << types[0]; + return false; + } + const auto* param = attrs.as(); + const auto& first = Downcast(tensor_tuple->fields[0]); + // Sanity check: ndim and dtype. + const int ndim = static_cast(first->shape.size()); + const DataType dtype = first->dtype; + for (const Type& ele : tensor_tuple->fields) { + const auto& e = Downcast(ele); + int e_ndim = static_cast(e->shape.size()); + const DataType& e_dtype = e->dtype; + CHECK_EQ(e_ndim, ndim) << "relay.stack requires all tensors have the same ndim"; + CHECK_EQ(e_dtype, dtype) << "relay.stack requires all tensors have the same dtype"; + } + // Sanity check: axis + int axis = param->axis; + CHECK(-ndim <= axis && axis < ndim) + << "stack only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis + << ", and ndim = " << ndim; + axis = axis < 0 ? 
ndim + axis + 1 : axis; + // Calculate shape + std::vector oshape; + oshape.reserve(ndim + 1); + const int stack_dim = static_cast(tensor_tuple->fields.size()); + for (int i = 0; i < axis; ++i) { + oshape.emplace_back(first->shape[i]); + } + oshape.emplace_back(stack_dim); + for (int i = axis; i < ndim; ++i) { + oshape.emplace_back(first->shape[i]); + } + reporter->Assign(types[1], TensorTypeNode::make(oshape, dtype)); + return true; +} + +bool TransposeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // types: [data, result] + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "transpose: expect input type to be TensorType but get " + << types[0]; + return false; + } + const auto* param = attrs.as(); + const int ndim = data->shape.size(); + const Array& axes = param->axes; + // check dimension match + CHECK(!axes.defined() || static_cast(axes.size()) == ndim) + << "Dimension mismatch: axes has " << axes.size() << " elements" + << ", but data.ndim = " << ndim; + // construct int_axes + std::vector int_axes; + int_axes.reserve(ndim); + // used not defined to check if it is None. + if (!axes.defined()) { + for (int i = ndim - 1; i >= 0; --i) { + int_axes.push_back(i); + } + } else { + std::vector axis_used(ndim, 0); + for (const Integer& e : axes) { + int64_t axis = e; + // sanity check for axis and ndim + CHECK(-ndim <= axis && axis < ndim) + << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)" + << ", but got axis = " << axis + << ", and data.ndim = " << ndim; + axis = axis < 0 ? 
axis + ndim : axis; + // sanity check for duplication + CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis; + axis_used[axis] = 1; + int_axes.push_back(static_cast(axis)); + } + } + std::vector oshape; + oshape.reserve(ndim); + for (int axis : int_axes) { + oshape.push_back(data->shape[axis]); + } + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + +bool ReshapeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // types: [data, result] + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "reshape: expect input type to be TensorType but get " + << types[0]; + return false; + } + + const auto* param = attrs.as(); + Array data_shape; + Array newshape; + if (param->reverse) { + data_shape.assign(data->shape.rbegin(), data->shape.rend()); + newshape.assign(param->newshape.rbegin(), param->newshape.rend()); + } else { + data_shape = data->shape; + newshape = param->newshape; + } + Array oshape; + std::unordered_set used_input_dims; + std::unordered_set used_output_dims; + size_t src_idx = 0; + int infer_idx = -1; + + for (size_t i = 0; i < newshape.size(); ++i) { + int svalue = newshape[i]->value; + // special flag handling for shape inference. 
+ if (svalue > 0) { + oshape.push_back(newshape[i]); + ++src_idx; + } else if (svalue == 0) { + // keep same + CHECK_LT(src_idx, data_shape.size()); + used_input_dims.insert(src_idx); + used_output_dims.insert(oshape.size()); + oshape.push_back(data_shape[src_idx++]); + } else if (svalue == -1) { + // inference based on rest + CHECK_LT(infer_idx, 0) + << "One and only one dim can be inferred"; + infer_idx = i; + oshape.push_back(1); + ++src_idx; + } else if (svalue == -2) { + // copy all remaining dims from source + while (src_idx < data_shape.size()) { + used_input_dims.insert(src_idx); + used_output_dims.insert(oshape.size()); + oshape.push_back(data_shape[src_idx++]); + } + } else if (svalue == -3) { + // merge two dims from source + CHECK_LT(src_idx + 1, data_shape.size()); + used_input_dims.insert(src_idx); + IndexExpr d1 = data_shape[src_idx++]; + used_input_dims.insert(src_idx); + IndexExpr d2 = data_shape[src_idx++]; + used_output_dims.insert(oshape.size()); + oshape.push_back(d1 * d2); + } else if (svalue == -4) { + // split the source dim s into two dims + // read the left dim and then the right dim (either can be -1) + CHECK_LT(i + 2, newshape.size()); + CHECK_LT(src_idx, data_shape.size()); + used_input_dims.insert(src_idx); + IndexExpr d0 = data_shape[src_idx++]; + Integer d1 = newshape[++i]; + Integer d2 = newshape[++i]; + if (d1->value == -1) { + CHECK(d2->value != -1) + << "Split dims cannot both be -1."; + used_output_dims.insert(oshape.size()); + if (d0.as()) { + oshape.push_back(Any::make()); + } else { + oshape.push_back(d0 / d2); + } + used_output_dims.insert(oshape.size()); + oshape.push_back(d2); + } else { + used_output_dims.insert(oshape.size()); + oshape.push_back(d1); + used_output_dims.insert(oshape.size()); + if (d2->value == -1) { + if (d0.as()) { + oshape.push_back(Any::make()); + } else { + oshape.push_back(d0 / d1); + } + } else { + oshape.push_back(d2); + } + } + } + } + + if (infer_idx >= 0) { + IndexExpr infer_dim = 1; + for 
(size_t i = 0; i < data_shape.size(); ++i) { + if (used_input_dims.count(i) != 0) { + continue; + } + if (data_shape[i].as()) { + infer_dim = Any::make(); + break; + } + infer_dim *= data_shape[i]; + } + if (!infer_dim.as()) { + for (size_t i = 0; i < oshape.size(); ++i) { + if (used_output_dims.count(i) != 0) { + continue; + } + if (oshape[i].as()) { + infer_dim = Any::make(); + break; + } + infer_dim /= oshape[i]; + } + } + oshape.Set(infer_idx, infer_dim); + } + + if (param->reverse) { + reporter->Assign(types[1], TensorTypeNode::make( + Array(oshape.rbegin(), oshape.rend()), data->dtype)); + } else { + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + } + return true; +} + +/*! +* \brief ReshapeLikeRel User defined type constraint function. +* \param num_inputs Number of input types in the args. +* \param attrs The additional attributes of the operator. +* \param reporter The reporter to report solution to. +* \return False if the relation has not been resolved, it might be resolved later. +* True if this relation has been resolved. 
+*/ +bool ReshapeLikeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) { + return false; + } + const auto* reshape_like = types[1].as(); + if (reshape_like == nullptr) { + return false; + } + CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size())) + << "Reshape inputs size should be compatible."; + reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype)); + return true; +} + +bool TakeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, indices, result] + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + CHECK(data != nullptr); + const auto* indices = types[1].as(); + CHECK(indices != nullptr); + const auto param = attrs.as(); + CHECK(param != nullptr); + + if (!param->axis.defined()) { + std::vector oshape(indices->shape.begin(), indices->shape.end()); + reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); + return true; + } + + std::vector oshape; + const auto ndim_data = static_cast(data->shape.size()); + const auto ndim_indices = static_cast(indices->shape.size()); + int axis = static_cast(param->axis->value); + if (axis < 0) axis += ndim_data; + CHECK_LE(axis, ndim_data) + << "axis should be with in data shape" + << ", but got = " << axis; + + oshape.reserve(ndim_data - 1 + ndim_indices); + for (int i = 0; i < axis; ++i) { + oshape.emplace_back(data->shape[i]); + } + for (int i = 0; i < ndim_indices; ++i) { + oshape.emplace_back(indices->shape[i]); + } + for (int i = axis+1; i < ndim_data; ++i) { + oshape.emplace_back(data->shape[i]); + } + + reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + +bool FullRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const InitOpAttrs* 
param = attrs.as(); + const auto* fill_value = types[0].as(); + if (fill_value == nullptr) { + return false; + } + + DataType out_dtype = param->dtype; + if (out_dtype.bits() == 0) { + out_dtype = fill_value->dtype; + } + + CHECK_EQ(fill_value->shape.size(), 0) + << "Fill value should be a scalar but has dimension " + << fill_value->shape.size() << "."; + + reporter->Assign(types[1], TensorTypeNode::make(param->shape, out_dtype)); + return true; +} + +bool InitOpRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 1); + const InitOpAttrs* param = attrs.as(); + + reporter->Assign(types[0], TensorTypeNode::make(param->shape, param->dtype)); + return true; +} + +bool FullLikeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) { + return false; + } + const auto* fill_value = types[1].as(); + if (fill_value == nullptr) { + return false; + } + + CHECK_EQ(fill_value->shape.size(), 0) + << "The fill value should be a scalar but here it has dimension " + << fill_value->shape.size() << "."; + + reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype)); + return true; +} + +double ToScalar(const runtime::NDArray& array) { + if (array->dtype.code == kDLInt || array->dtype.code == kDLUInt) { + return reinterpret_cast(array->data)[0]; + } else { + return reinterpret_cast(array->data)[0]; + } +} + +bool ArangeRel(const Array& types, + int num_inputs, + const Attrs& raw_attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 4); + const ArangeAttrs* attrs = raw_attrs.as(); + const ConstantNode *cstart, *cstop, *cstep; + + reporter->Assign(types[0], types[1]); + reporter->Assign(types[1], types[2]); + reporter->Assign(types[2], TensorTypeNode::make({}, attrs->dtype)); + + if ((cstart = attrs->start.as()) && + (cstop = attrs->stop.as()) && + (cstep = 
attrs->step.as())) { + double start = ToScalar(cstart->data); + double stop = ToScalar(cstop->data); + double step = ToScalar(cstep->data); + int32_t num_elem = static_cast(std::ceil((stop - start) / step)); + CHECK_GT(num_elem, 0) + << "Invalid arange attributes (start, stop, step): " << attrs->start + << ", " << attrs->stop << ", " << attrs->step; + reporter->Assign(types[3], TensorTypeNode::make({num_elem}, attrs->dtype)); + return true; + } else { + reporter->Assign(types[3], TensorTypeNode::make({Any::make()}, attrs->dtype)); + return true; + } +} + +bool RepeatRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, result] + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "repeat: expect input type to be TensorType but get " + << types[0]; + return false; + } + const auto* param = attrs.as(); + const int ndim = static_cast(data->shape.size()); + const int repeats = param->repeats; + const int axis = param->axis; + CHECK(repeats >= 1) + << "repeat only accepts `repeats >= 1`" + << ", but got repeats = " << repeats; + CHECK(-ndim - 1 <= axis && axis <= ndim) + << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis + << ", and data.ndim = " << ndim; + const int pivot = axis < 0 ? 
ndim + axis : axis; + std::vector oshape; + oshape.reserve(ndim + repeats); + for (int i = 0; i < pivot; ++i) { + oshape.emplace_back(data->shape[i]); + } + oshape.emplace_back(data->shape[pivot] * repeats); + for (int i = pivot + 1; i < ndim; ++i) { + oshape.emplace_back(data->shape[i]); + } + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + +bool TileRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, result] + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "tile: expect input type to be TensorType but get " + << types[0]; + return false; + } + const auto* param = attrs.as(); + const size_t ndim = data->shape.size(); + const Array& reps = param->reps; + // check dimension match + CHECK(reps.defined()) + << "repetition array is not defined. data.ndim = " << ndim; + const size_t rndim = reps.size(); + for (size_t i = 0; i < rndim; ++i) { + if (const tvm::ir::IntImm* val = reps[i].as()) { + CHECK_GT(val->value, 0) + << "Tile reps value should always be larger than 0, but get: " << val->value; + } + } + size_t tndim = (ndim > rndim) ? 
ndim : rndim; + // re-construct data shape or reps shape + std::vector data_shape; + std::vector reps_shape; + data_shape.reserve(tndim); + reps_shape.reserve(tndim); + if (ndim == rndim) { + for (size_t i = 0; i < tndim; ++i) { + data_shape.emplace_back(data->shape[i]); + reps_shape.emplace_back(reps[i]); + } + } else if (ndim > rndim) { + for (size_t i = 0; i < ndim; ++i) + data_shape.emplace_back(data->shape[i]); + for (size_t i = 0; i < (ndim - rndim); ++i) + reps_shape.emplace_back(1); + for (size_t i = 0; i < rndim; ++i) + reps_shape.emplace_back(reps[i]); + } else { + for (size_t i = 0; i < rndim; ++i) + reps_shape.emplace_back(reps[i]); + for (size_t i = 0; i < (rndim - ndim); ++i) + data_shape.emplace_back(1); + for (size_t i = 0; i < ndim; ++i) + data_shape.emplace_back(data->shape[i]); + } + std::vector oshape; + oshape.reserve(tndim); + for (size_t i = 0; i < tndim; ++i) { + oshape.emplace_back(data_shape[i] * reps_shape[i]); + } + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + +bool ReverseRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, result] + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "reverse: expect input type to be TensorType but get " + << types[0]; + return false; + } + const auto* param = attrs.as(); + const int ndim = static_cast(data->shape.size()); + const int axis = param->axis; + CHECK(-ndim <= axis && axis < ndim) + << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]" + << ", but got axis = " << axis + << ", and data.ndim = " << ndim; + reporter->Assign(types[1], types[0]); + return true; +} + +bool WhereRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 4U); + const auto* condition = types[0].as(); + const auto* x = types[1].as(); + const auto* y = 
types[2].as(); + CHECK(condition != nullptr && x != nullptr && y != nullptr); + + const auto& cond_shape = condition->shape; + const auto& x_shape = x->shape; + const auto& y_shape = y->shape; + CHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size"; + + if (cond_shape.size() != x_shape.size()) { + CHECK_EQ(cond_shape.size(), 1) + << "Shape of condition " << condition->shape + << " must be either equal to x or has dimension of 1."; + } + for (size_t i = 0; i < x_shape.size(); i++) { + CHECK(reporter->AssertEQ(x_shape[i], y_shape[i])) + << "x and y must have the same shape: " << x_shape << " vs " << y_shape; + + if (i < cond_shape.size()) { + CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i])) + << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape; + } + } + reporter->Assign(types[3], TensorTypeNode::make(x_shape, x->dtype)); + return true; +} + +bool SqueezeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + return false; + } + const auto* param = attrs.as(); + CHECK(param != nullptr); + std::vector result_shape; + // if axes is None, squeeze all axes of dimension 1 + if (!param->axis.defined()) { + for (const auto& e : data->shape) { + const int64_t* axis_ptr = as_const_int(e); + CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete"; + if (*axis_ptr != 1) { + result_shape.push_back(e); + } + } + } else { + // pair up original shape with a boolean which control whether it will be in the final shape. 
+    // Pair each input dim with a keep-flag; axes named in `param->axis` are marked for removal.
+    std::vector<std::pair<IndexExpr, bool> > original_shape;
+    for (const auto& e : data->shape) {
+      original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
+    }
+    for (const auto& e : param->axis) {
+      int64_t axis_val = e->value;
+      if (axis_val < 0) {
+        axis_val += static_cast<int64_t>(original_shape.size());
+      }
+      CHECK_GE(axis_val, 0);
+      CHECK_LT(axis_val, original_shape.size());
+      original_shape.at(axis_val).second = false;
+    }
+    for (const auto p : original_shape) {
+      if (p.second) {
+        result_shape.push_back(p.first);
+      } else {
+        // A squeezed axis must be a concrete dimension equal to 1.
+        const int64_t* axis_ptr = as_const_int(p.first);
+        CHECK(axis_ptr != nullptr) << "cannot get concrete shape of input tensor";
+        CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1";
+      }
+    }
+  }
+  reporter->Assign(types[1], TensorTypeNode::make(result_shape, data->dtype));
+  return true;
+}
+
+// Have no idea how to assert the constraint.
+// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
+bool CollapseSumLikeRel(const Array<Type>& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  reporter->Assign(types[2], types[1]);
+  return true;
+}
+
+// BroadCastTo: <A, B> -> B where BroadCast(A, B) = B
+bool BroadCastToRel(const Array<Type>& types,
+                    int num_inputs,
+                    const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  auto ioattrs = attrs.as<InitOpAttrs>();
+  CHECK(ioattrs);
+  auto intt = types[0].as<TensorTypeNode>();
+  if (intt == nullptr) { return false; }
+  auto type = TensorTypeNode::make(ioattrs->shape, intt->dtype);
+  reporter->Assign(types[1], type);
+  return true;
+}
+
+// BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B
+bool BroadCastToLikeRel(const Array<Type>& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  reporter->Assign(types[2], types[1]);
+  return true;
+}
+
+bool StridedSliceRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
+  CHECK(param != nullptr);
+
+  auto dshape = data->shape;
+  auto num_axis = dshape.size();
+
+  // Missing strides default to 1 for trailing axes.
+  std::vector<int64_t> stride_vec;
+  for (Integer i : param->strides) {
+    CHECK(i.defined());
+    stride_vec.push_back(i->value);
+  }
+  for (size_t i = stride_vec.size(); i < num_axis; ++i) {
+    stride_vec.push_back(1);
+  }
+  const int64_t max_range = std::numeric_limits<int64_t>::max();
+
+  std::vector<int64_t> begin_vec;
+  for (size_t i = 0; i < param->begin.size(); ++i) {
+    if (!param->begin[i].defined()) {
+      // value=None
+      begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
+    } else {
+      begin_vec.push_back(param->begin[i]->value);
+    }
+  }
+  for (size_t i = begin_vec.size(); i < num_axis; ++i) {
+    begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
+  }
+
+  std::vector<int64_t> end_vec;
+  for (size_t i = 0; i < param->end.size(); ++i) {
+    // allow end to be None
+    if (!param->end[i].defined()) {
+      end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
+    } else {
+      end_vec.push_back(param->end[i]->value);
+    }
+  }
+  for (size_t i = end_vec.size(); i < num_axis; ++i) {
+    end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
+  }
+
+  std::vector<IndexExpr> oshape(dshape.size());
+  for (size_t i = 0; i < num_axis; ++i) {
+    int64_t stride_v = stride_vec[i];
+    int64_t begin_v = begin_vec[i];
+    int64_t end_v = end_vec[i];
+
+    if ((stride_v == 1 &&
+         begin_v == 0 &&
+         end_v == max_range) ||
+        (stride_v == -1 &&
+         begin_v == max_range &&
+         end_v == 0)) {
+      // Quick path, do not slice this dimension.
+      oshape[i] = dshape[i];
+      continue;
+    }
+    // Normal path, require the shape to be concrete integer.
+    // Require concrete integer as symbolic inference of min/max
+    // can get complicated and not very helpful.
+    const int64_t* p_dim_size = as_const_int(dshape[i]);
+    CHECK(p_dim_size)
+        << "strided_slice requires sliced dimension to be concrete int";
+    int64_t dim_size = p_dim_size[0];
+    begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v;
+    end_v = (end_v < 0) ? dim_size + end_v : end_v;
+
+    int64_t slice_range, step;
+    if (stride_v < 0) {
+      if (end_v < -1) end_v = -1;
+      CHECK_LT(end_v, begin_v)
+          << "strided_slice get empty slice at axis " << i;
+      begin_v = std::min(dim_size - 1, begin_v);
+      slice_range = begin_v - end_v;
+      step = -stride_v;
+    } else {
+      if (begin_v < 0) begin_v = 0;
+      CHECK_GE(stride_v, 0);
+      CHECK_LT(begin_v, end_v)
+          << "strided_slice get empty slice at axis " << i;
+      end_v = std::min(dim_size, end_v);
+      slice_range = end_v - begin_v;
+      step = stride_v;
+    }
+    // Ceil-divide the covered range by the step to get the output extent.
+    oshape[i] = make_const(dshape[i].type(), (slice_range + step - 1) / step);
+  }
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+bool SplitRel(const Array<Type>& types,
+              int num_inputs,
+              const Attrs& attrs,
+              const TypeReporter& reporter) {
+  // `types` contains: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
+  const auto param = attrs.as<SplitAttrs>();
+  CHECK(param != nullptr);
+  auto axis = param->axis;
+  if (axis < 0) {
+    axis += data->shape.size();
+  }
+  CHECK_LT(axis, data->shape.size())
+    << "axis should be within the input dimension range.";
+  CHECK_GE(axis, 0)
+    << "axis should be within the input dimension range.";
+
+  if (const IntImm* sections = param->indices_or_sections.as<IntImm>()) {
+    CHECK(reporter->Assert(data->shape[axis] %
+                           sections->value == make_zero(Int(64))))
+        << "indices_or_sections need to be able to divide input.shape[axis]";
+    std::vector<Type> fields;
+    for (int i = 0; i < sections->value; ++i) {
+      std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
+      oshape[axis] /= int32_t(sections->value);
+      auto vec_type = TensorTypeNode::make(oshape, data->dtype);
+      fields.push_back(vec_type);
+    }
+    reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
+  } else {
+    auto indices = param->indices_or_sections.as<ArrayNode>()->data;
+    auto begin = IndexExpr(make_zero(Int(32)));
+    std::vector<Type> fields;
+    for (unsigned int i = 0; i < indices.size(); ++i) {
+      CHECK(reporter->Assert(IndexExpr(indices[i]) > begin))
+          << "indices_or_sections need to be a sorted ascending list";
+      std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
+      oshape[axis] = IndexExpr(indices[i]) - begin;
+      begin = IndexExpr(indices[i]);
+      auto vec_type = TensorTypeNode::make(oshape, data->dtype);
+      fields.push_back(vec_type);
+    }
+    CHECK(reporter->Assert(begin < data->shape[axis]))
+        << "The sum of sections must match the input.shape[axis]";
+    // Last field covers the remainder after the final split index.
+    std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
+    oshape[axis] = data->shape[axis] - begin;
+    auto vec_type = TensorTypeNode::make(oshape, data->dtype);
+    fields.push_back(vec_type);
+    reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
+  }
+  return true;
+}
+
+/*!
+* \brief SliceLikeRel User defined type constraint function.
+* \param num_inputs Number of input types in the args.
+* \param attrs The additional attributes of the operator.
+* \param reporter The reporter to report solution to.
+* \return False if the relation has not been resolved, it might be resolved later.
+* True if this relation has been resolved.
+*/
+bool SliceLikeRel(const Array<Type>& types,
+                  int num_inputs,
+                  const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+
+  const auto* target = types[1].as<TensorTypeNode>();
+  if (target == nullptr) {
+    return false;
+  }
+
+  const auto param = attrs.as<SliceLikeAttrs>();
+  CHECK(param != nullptr);
+
+  const Array<IndexExpr>& dshape = data->shape;
+  const Array<IndexExpr>& target_shape = target->shape;
+  std::vector<IndexExpr> oshape(dshape.begin(), dshape.end());
+
+  if (!param->axes.defined()) {
+    for (size_t i = 0; i < dshape.size(); ++i) {
+      if (i < target_shape.size()) {
+        oshape[i] = target_shape[i];
+        CHECK(reporter->Assert(oshape[i] <= dshape[i]))
+          << "End index of axis " << i << " exceeds input shape: "
+          << oshape[i] << " vs " << dshape[i];
+      }
+    }
+  } else {
+    CHECK(param->axes.size() != 0) << "Axes cannot be empty.";
+    for (Integer val : param->axes) {
+      int axis = val->value;
+      if (axis < 0) {
+        axis += dshape.size();
+      }
+      CHECK(axis < static_cast<int>(target_shape.size()))
+        << "Axis " << axis << " exceeds dimension "
+        << target_shape.size() << " of target_shape.";
+      oshape[axis] = target_shape[axis];
+      CHECK(reporter->Assert(oshape[axis] <= dshape[axis]))
+        << "End index of axis " << axis << " exceeds input shape: "
+        << oshape[axis] << " vs " << dshape[axis];
+    }
+  }
+
+  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+bool LayoutTransformRel(const Array<Type>& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  const auto* data = types[0].as<TensorTypeNode>();
+  CHECK(data != nullptr);
+  const LayoutTransformAttrs* params = attrs.as<LayoutTransformAttrs>();
+
+  Layout src_layout(params->src_layout);
+  Layout dst_layout(params->dst_layout);
+
+  CHECK(src_layout.defined() && dst_layout.defined())
+    << "cannot convert from/to undefined layout";
+
+  auto layout_converter = BijectiveLayoutNode::make(src_layout, dst_layout);
+  CHECK(layout_converter.defined())
+    << "cannot convert from " << params->src_layout << " to " << params->dst_layout;
+
+  const auto& out_shape = layout_converter.ForwardShape(data->shape);
+  reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype));
+  return true;
+}
+
+// gather_nd operator
+bool GatherNDRel(const Array<Type>& types,
+                 int num_inputs,
+                 const Attrs& attrs,
+                 const TypeReporter& reporter) {
+  // `types` contains: [data, indices, result]
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* indices = types[1].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "GatherND: expect input data type to be TensorType but get "
+        << types[0];
+    return false;
+  }
+  if (indices == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "GatherND: expect indices type to be TensorType but get "
+        << types[1];
+    return false;
+  }
+  const size_t ndim = data->shape.size();
+  const IntImm* mdim = indices->shape[0].as<IntImm>();
+  const size_t kdim = indices->shape.size() - 1;
+  CHECK(size_t(mdim->value) <= ndim)
+      << "GatherND: indices shape does satisfy.";
+
+  Array<IndexExpr> oshape;
+  for (size_t i = 1; i < kdim + 1; ++i)
+    oshape.push_back(indices->shape[i]);
+  for (size_t i = mdim->value; i < ndim; ++i)
+    oshape.push_back(data->shape[i]);
+  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+bool SequenceMaskRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  // `types` contains: [data, valid_length, result]
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* valid_length = types[1].as<TensorTypeNode>();
+  CHECK(data);
+  CHECK(valid_length);
+  const auto param = attrs.as<SequenceMaskAttrs>();
+  Array<IndexExpr> valid_length_shape;
+  CHECK(param->axis == 0 || param->axis == 1);
+  valid_length_shape.push_back(data->shape[1 - param->axis]);
+  reporter->Assign(types[1], TensorTypeNode::make(valid_length_shape, valid_length->dtype));
+  reporter->Assign(types[2], types[0]);
+  return true;
+}
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_OP_TENSOR_TRANSFORM_H_