diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 5eb54a1325ba..2f59fb9db19c 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -29,6 +29,7 @@ #include #include "../../pass/alter_op_layout.h" +#include "convolution.h" namespace tvm { namespace relay { @@ -36,111 +37,6 @@ namespace relay { // relay.nn.conv2d TVM_REGISTER_NODE_TYPE(Conv2DAttrs); -bool Conv2DRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr) return false; - static const Layout kNCHW("NCHW"); - static const Layout kOIHW("OIHW"); - - const Conv2DAttrs* param = attrs.as(); - CHECK(param != nullptr); - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW); - CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCHW." - << " But got " << in_layout; - - const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW); - CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIHW." - << " But got "<< kernel_layout; - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW); - CHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCHW." - << " But got " << out_layout; - - Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); - - IndexExpr channels, dilated_ksize_y, dilated_ksize_x; - // infer weight if the kernel_size and channels are defined - if (param->kernel_size.defined() && param->channels.defined()) { - CHECK_EQ(param->kernel_size.size(), 2); - CHECK_EQ(param->dilation.size(), 2); - Array wshape; - - if (tvm::ir::Equal(param->channels, param->groups)) { - // infer weight's shape for depthwise convolution - wshape = { - {dshape_nchw[1], - param->groups / dshape_nchw[1], - param->kernel_size[0], - param->kernel_size[1]}}; - } else { - wshape = { - {param->channels, - dshape_nchw[1] / param->groups, - param->kernel_size[0], - param->kernel_size[1]}}; - } - - wshape = trans_kernel_layout.BackwardShape(wshape); - channels = param->channels; - dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; - DataType weight_dtype = data->dtype; - if (weight != nullptr) { - weight_dtype = weight->dtype; - } - // assign result to reporter - reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype)); - } else { - // use weight to infer the conv shape. - if (weight == nullptr) return false; - auto wshape = trans_kernel_layout.ForwardShape(weight->shape); - if (param->kernel_size.defined()) { - CHECK_EQ(param->kernel_size.size(), 2); - // check the size - CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && - reporter->AssertEQ(param->kernel_size[1], wshape[3])) - << "Conv2D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << wshape; - } - if (param->channels.defined()) { - CHECK(reporter->AssertEQ(param->channels, wshape[0])) - << "Conv2D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << wshape; - } - CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1])); - channels = wshape[0]; - dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; - } - // dilation - Array oshape({dshape_nchw[0], channels, 0, 0}); - - oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1); - oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1); - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - oshape = trans_out_layout.BackwardShape(oshape); - // assign output type - reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); - return true; -} - template Array > Conv2DInferCorrectLayout( const Attrs& attrs, @@ -208,7 +104,7 @@ with the layer input to produce a tensor of outputs. .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) -.add_type_rel("Conv2D", Conv2DRel) +.add_type_rel("Conv2D", Conv2DRel) .set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout); @@ -770,7 +666,7 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(10) -.add_type_rel("Conv2D", Conv2DRel) +.add_type_rel("Conv2D", Conv2DRel) .set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout); diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h new file mode 100644 index 000000000000..fb5844749117 --- /dev/null +++ b/src/relay/op/nn/convolution.h @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/op/nn/convolution.h + * \brief Properties def of convlution operator for sharing. + */ +#ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_ +#define TVM_RELAY_OP_NN_CONVOLUTION_H_ + +#include +#include + +namespace tvm { +namespace relay { + +template +bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr) return false; + static const Layout kNCHW("NCHW"); + static const Layout kOIHW("OIHW"); + + const AttrType* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW); + CHECK(trans_in_layout.defined()) + << "Conv only support input layouts that are convertible from NCHW." + << " But got " << in_layout; + + const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW); + CHECK(trans_kernel_layout.defined()) + << "Conv only support kernel layouts that are convertible from OIHW." + << " But got " << kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW); + CHECK(trans_out_layout.defined()) + << "Conv only support output layouts that are convertible from NCHW." + << " But got " << out_layout; + + Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); + + IndexExpr channels, dilated_ksize_y, dilated_ksize_x; + // infer weight if the kernel_size and channels are defined + if (param->kernel_size.defined() && param->channels.defined()) { + CHECK_EQ(param->kernel_size.size(), 2); + CHECK_EQ(param->dilation.size(), 2); + Array wshape; + + if (tvm::ir::Equal(param->channels, param->groups)) { + // infer weight's shape for depthwise convolution + wshape = {{dshape_nchw[1], param->groups / dshape_nchw[1], param->kernel_size[0], + param->kernel_size[1]}}; + } else { + wshape = {{param->channels, dshape_nchw[1] / param->groups, param->kernel_size[0], + param->kernel_size[1]}}; + } + + wshape = trans_kernel_layout.BackwardShape(wshape); + channels = param->channels; + dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + DataType weight_dtype = data->dtype; + if (weight != nullptr) { + weight_dtype = weight->dtype; + } + // assign result to reporter + reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype)); + } else { + // use weight to infer the conv shape. + if (weight == nullptr) return false; + auto wshape = trans_kernel_layout.ForwardShape(weight->shape); + if (param->kernel_size.defined()) { + CHECK_EQ(param->kernel_size.size(), 2); + // check the size + CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && + reporter->AssertEQ(param->kernel_size[1], wshape[3])) + << "Conv2D: shape of weight is inconsistent with kernel_size, " + << " kernel_size=" << param->kernel_size << " wshape=" << wshape; + } + if (param->channels.defined()) { + CHECK(reporter->AssertEQ(param->channels, wshape[0])) + << "Conv2D: shape of weight is inconsistent with channels, " + << " channels=" << param->channels << " wshape=" << wshape; + } + CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1])); + channels = wshape[0]; + dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; + } + // dilation + Array oshape({dshape_nchw[0], channels, 0, 0}); + + oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1); + oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1); + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + // assign output type + reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + return true; +} + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_OP_NN_CONVOLUTION_H_ diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 2c03bbac70d7..42a0f011d668 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -35,6 +35,7 @@ #include "../type_relations.h" #include "../../pass/alter_op_layout.h" #include "../op_common.h" +#include "nn.h" namespace tvm { namespace relay { @@ -102,45 +103,6 @@ RELAY_REGISTER_OP("nn.bias_add") // relay.nn.dense TVM_REGISTER_NODE_TYPE(DenseAttrs); - -bool DenseRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr) return false; - - const DenseAttrs* param = attrs.as(); - CHECK(param != nullptr); - - CHECK(static_cast(data->shape.size()) != 0); - - Array oshape = data->shape; - if (param->units.defined()) { - Array dshape = data->shape; - // validate the weight shape is proper if defined - // Assign weight type - Array wshape({param->units, dshape[dshape.size() - 1]}); - reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype)); - oshape.Set((oshape.size() - 1), param->units); - } else { - if (weight == nullptr) return false; - Array wshape = weight->shape; - oshape.Set((oshape.size() - 1), wshape[0]); - } - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - // assign output type - reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); - return true; -} - - // Positional relay function to create dense operator used by frontend FFI. Expr MakeDense(Expr data, Expr weight, @@ -171,7 +133,7 @@ RELAY_REGISTER_OP("nn.dense") .add_argument("data", "nD Tensor", "Input data.") .add_argument("weight", "2D Tensor", "Weight matrix.") .set_support_level(1) -.add_type_rel("Dense", DenseRel); +.add_type_rel("Dense", DenseRel); // relay.leaky_relu TVM_REGISTER_NODE_TYPE(LeakyReluAttrs); diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h new file mode 100644 index 000000000000..2c65d2526437 --- /dev/null +++ b/src/relay/op/nn/nn.h @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/op/nn/nn.h + * \brief Properties def of nn operators for sharing. + */ +#ifndef TVM_RELAY_OP_NN_NN_H_ +#define TVM_RELAY_OP_NN_NN_H_ + +#include + +namespace tvm { +namespace relay { + +template +bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr) return false; + + const AttrType* param = attrs.as(); + CHECK(param != nullptr); + + CHECK(static_cast(data->shape.size()) != 0); + + Array oshape = data->shape; + if (param->units.defined()) { + Array dshape = data->shape; + // validate the weight shape is proper if defined + // Assign weight type + Array wshape({param->units, dshape[dshape.size() - 1]}); + reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype)); + oshape.Set((oshape.size() - 1), param->units); + } else { + if (weight == nullptr) return false; + Array wshape = weight->shape; + oshape.Set((oshape.size() - 1), wshape[0]); + } + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + // assign output type + reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + return true; +} + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_OP_NN_NN_H_ diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index b39c282b1d96..c3975c3a2808 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -37,6 +37,7 @@ #include "../op_common.h" #include "../../../arithmetic/compute_expr.h" #include "../../pass/alter_op_layout.h" +#include "transform.h" namespace tvm { namespace relay { @@ -210,86 +211,6 @@ RELAY_REGISTER_OP("expand_dims") // relay.concatenate TVM_REGISTER_NODE_TYPE(ConcatenateAttrs); -bool ConcatenateRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - // types: [data, result] - CHECK_EQ(types.size(), 2); - /* If we receive a tuple we can continue, if we receive - * anything but an incomplete type we should signal an - * error. - */ - const auto* tensor_tuple = types[0].as(); - if (tensor_tuple == nullptr) { - throw relay::Error( - RELAY_ERROR( - "concatenate requires a tuple of tensors as the first argument, found " - << PrettyPrint(types[0]))); - } else if (types[0].as() != nullptr) { - return false; - } - - const auto* param = attrs.as(); - if (tensor_tuple->fields[0].as()) { - return false; - } - const auto& first = Downcast(tensor_tuple->fields[0]); - // Sanity check: ndim and dtype. - const int ndim = static_cast(first->shape.size()); - const DataType dtype = first->dtype; - - for (const Type& ele : tensor_tuple->fields) { - if (ele.as()) { - return false; - } - - const auto& e = Downcast(ele); - - int e_ndim = static_cast(e->shape.size()); - const DataType& e_dtype = e->dtype; - if (e_ndim != ndim) { - throw relay::Error("relay.concatenate requires all tensors have the same ndim"); - } - if (e_dtype != dtype) { - throw relay::Error("relay.concatenate requires all tensors have the same dtype"); - } - } - // Sanity check: axis - int axis = param->axis; - if (!(-ndim <= axis && axis < ndim)) { - throw relay::Error(RELAY_ERROR( - "concatenate only accepts `axis` in [-ndim, ndim)" << - ", but got axis = " << axis << - ", and ndim = " << ndim)); - } - axis = axis < 0 ? ndim + axis : axis; - // Calculate shape - std::vector oshape(first->shape.begin(), first->shape.end()); - IndexExpr &concat_dim = oshape[axis]; - bool has_any = false; - if (concat_dim.as()) { - has_any = true; - } else { - for (int i = 1; i < static_cast(tensor_tuple->fields.size()); ++i) { - const auto& e = Downcast(tensor_tuple->fields[i]); - if (e->shape[axis].as()) { - has_any = true; - break; - } - concat_dim += e->shape[axis]; - } - } - - if (has_any) { - concat_dim = Any::make(); - } - - auto rtype = TensorTypeNode::make(oshape, dtype); - reporter->Assign(types[1], rtype); - return true; -} - Array ConcatenateCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -358,7 +279,7 @@ RELAY_REGISTER_OP("concatenate") .set_num_inputs(1) .add_argument("data", "Tensor", "The input list of tensors.") .set_support_level(1) -.add_type_rel("Concatenate", ConcatenateRel) +.add_type_rel("Concatenate", ConcatenateRel) .set_attr("FInferCorrectLayout", ConcatenateLayout) .set_attr("FTVMCompute", ConcatenateCompute) .set_attr("TOpPattern", kInjective); diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h new file mode 100644 index 000000000000..3a4d50bb8dce --- /dev/null +++ b/src/relay/op/tensor/transform.h @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/op/tensor/transform.h + * \brief Transform op attributes that can be shared among Relay and its dialects. + */ +#ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_ +#define TVM_RELAY_OP_TENSOR_TRANSFORM_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +template +bool ConcatenateRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // types: [data, result] + CHECK_EQ(types.size(), 2); + /* If we receive a tuple we can continue, if we receive + * anything but an incomplete type we should signal an + * error. + */ + const auto* tensor_tuple = types[0].as(); + if (tensor_tuple == nullptr) { + throw relay::Error( + RELAY_ERROR( + "concatenate requires a tuple of tensors as the first argument, found " + << PrettyPrint(types[0]))); + } else if (types[0].as() != nullptr) { + return false; + } + + const auto* param = attrs.as(); + if (tensor_tuple->fields[0].as()) { + return false; + } + const auto& first = Downcast(tensor_tuple->fields[0]); + // Sanity check: ndim and dtype. + const int ndim = static_cast(first->shape.size()); + const DataType dtype = first->dtype; + + for (const Type& ele : tensor_tuple->fields) { + if (ele.as()) { + return false; + } + + const auto& e = Downcast(ele); + + int e_ndim = static_cast(e->shape.size()); + const DataType& e_dtype = e->dtype; + if (e_ndim != ndim) { + throw relay::Error("relay.concatenate requires all tensors have the same ndim"); + } + if (e_dtype != dtype) { + throw relay::Error("relay.concatenate requires all tensors have the same dtype"); + } + } + // Sanity check: axis + int axis = param->axis; + if (!(-ndim <= axis && axis < ndim)) { + throw relay::Error(RELAY_ERROR( + "concatenate only accepts `axis` in [-ndim, ndim)" << + ", but got axis = " << axis << + ", and ndim = " << ndim)); + } + axis = axis < 0 ? ndim + axis : axis; + // Calculate shape + std::vector oshape(first->shape.begin(), first->shape.end()); + IndexExpr &concat_dim = oshape[axis]; + bool has_any = false; + if (concat_dim.as()) { + has_any = true; + } else { + for (int i = 1; i < static_cast(tensor_tuple->fields.size()); ++i) { + const auto& e = Downcast(tensor_tuple->fields[i]); + if (e->shape[axis].as()) { + has_any = true; + break; + } + concat_dim += e->shape[axis]; + } + } + + if (has_any) { + concat_dim = Any::make(); + } + + auto rtype = TensorTypeNode::make(oshape, dtype); + reporter->Assign(types[1], rtype); + return true; +} + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_OP_TENSOR_TRANSFORM_H_