From 0950905531e03c3e241a39526fc92da483910ef2 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 13 Jul 2020 20:42:54 -0700 Subject: [PATCH] Refactor to expose MakeOp functions to C++ (#6047) * Initial Refactor * add templated nn Make* functions * fix build typo * inline functions, fix unit tests --- python/tvm/relay/op/algorithm.py | 4 +- src/relay/op/make_op.h | 83 ++++++++++++ src/relay/op/nn/convolution.cc | 107 +--------------- src/relay/op/nn/convolution_make.h | 149 ++++++++++++++++++++++ src/relay/op/nn/nn.cc | 1 + src/relay/op/nn/pad.cc | 1 + src/relay/op/nn/pooling.cc | 30 +---- src/relay/op/nn/pooling.h | 65 ++++++++++ src/relay/op/tensor/reduce.cc | 17 ++- src/relay/op/tensor/transform.cc | 1 + src/relay/op/tensor/transform.h | 4 +- src/relay/op/tensor/unary.cc | 7 +- src/relay/transforms/dynamic_to_static.cc | 38 ++---- src/relay/transforms/pattern_util.h | 111 +++------------- 14 files changed, 347 insertions(+), 271 deletions(-) create mode 100644 src/relay/op/make_op.h create mode 100644 src/relay/op/nn/convolution_make.h create mode 100644 src/relay/op/nn/pooling.h diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 5aeb7e647b4e..f3c35b8bad03 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -16,7 +16,7 @@ # under the License. """Classic algorithm operation""" from __future__ import absolute_import as _abs -import numpy as np + from . import _make from .dyn import _make as _dyn_make from ..expr import TupleWrapper, Expr, Constant @@ -85,7 +85,7 @@ def topk(data, k=1, axis=-1, ret_type="both", The computed result. """ if isinstance(k, Constant): - k = np.asscalar(k.data.asnumpy()) + k = k.data.asnumpy().item() if isinstance(k, Expr): out = _dyn_make.topk(data, k, axis, ret_type, is_ascend, dtype) else: diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h new file mode 100644 index 000000000000..b5c7a526c658 --- /dev/null +++ b/src/relay/op/make_op.h @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file tvm/relay/op/make_op.h + * \brief Header of internal operator functions + * to assist in creating ops in C++ + */ +#ifndef TVM_RELAY_OP_MAKE_OP_H_ +#define TVM_RELAY_OP_MAKE_OP_H_ + +#include +#include + +// Include Templated Make Functions +#include "nn/convolution_make.h" +#include "nn/pooling.h" + +namespace tvm { +namespace relay { + +Expr MakeBroadCastTo(Expr data, Expr shape); + +Expr MakeCast(Expr data, DataType dtype); + +Expr MakeClip(Expr a, double a_min, double a_max); + +Expr MakeConcatenate(Expr data, int axis); + +Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype); + +Expr MakeExpandDims(Expr data, int axis, int num_newaxis); + +Expr MakeFull(Expr fill_value, Expr shape, DataType dtype); + +Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout); + +Expr MakeOnes(Expr shape, DataType dtype); + +Expr MakePad(Expr data, Array> pad_width, double pad_value, String pad_mode); + +Expr MakeReduce(Expr data, Array axis, bool keepdims, bool exclude, String op_name); + +Expr MakeRepeat(Expr data, int repeats, int axis); + +Expr MakeReshape(Expr data, Array newshape); + +Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis); + +Expr MakeSqueeze(Expr data, Array axis); + +Expr MakeStack(Expr data, int axis); + +Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, String slice_mode); + +Expr MakeTile(Expr data, Array reps); + +Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype); + +Expr MakeVariance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude); + +Expr MakeZeros(Expr shape, DataType dtype); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_OP_MAKE_OP_H_ diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index f63c48915f25..438500f45e5e 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -31,72 +31,11 @@ #include "../../transforms/infer_layout_util.h" #include "../op_common.h" +#include "convolution_make.h" namespace tvm { namespace relay { -template -Expr MakeConv(Expr data, Expr weight, Array strides, Array padding, - Array dilation, int groups, IndexExpr channels, - Array kernel_size, std::string data_layout, std::string kernel_layout, - std::string out_layout, DataType out_dtype, std::string op_name) { - auto attrs = make_object(); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->dilation = std::move(dilation); - attrs->groups = groups; - attrs->channels = std::move(channels); - attrs->kernel_size = std::move(kernel_size); - attrs->data_layout = std::move(data_layout); - attrs->kernel_layout = std::move(kernel_layout); - attrs->out_layout = std::move(out_layout); - attrs->out_dtype = std::move(out_dtype); - const Op& op = Op::Get(op_name); - return Call(op, {data, weight}, Attrs(attrs), {}); -} - -template -Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array strides, - Array padding, Array dilation, int groups, - IndexExpr channels, Array kernel_size, std::string data_layout, - std::string kernel_layout, std::string out_layout, DataType out_dtype, - std::string op_name) { - auto attrs = make_object(); - attrs->tile_size = tile_size; - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->dilation = std::move(dilation); - attrs->groups = groups; - attrs->channels = std::move(channels); - attrs->kernel_size = std::move(kernel_size); - attrs->data_layout = std::move(data_layout); - attrs->kernel_layout = std::move(kernel_layout); - attrs->out_layout = std::move(out_layout); - attrs->out_dtype = std::move(out_dtype); - const Op& op = Op::Get(op_name); - return Call(op, {data, weight}, Attrs(attrs), {}); -} - -template -Expr MakeConvGemm(Expr data, Expr weight, Array strides, Array padding, - Array dilation, int groups, IndexExpr channels, - Array kernel_size, std::string data_layout, std::string kernel_layout, - std::string out_layout, DataType out_dtype, std::string op_name) { - auto attrs = make_object(); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->dilation = std::move(dilation); - attrs->groups = groups; - attrs->channels = std::move(channels); - attrs->kernel_size = std::move(kernel_size); - attrs->data_layout = std::move(data_layout); - attrs->kernel_layout = std::move(kernel_layout); - attrs->out_layout = std::move(out_layout); - attrs->out_dtype = std::move(out_dtype); - const Op& op = Op::Get(op_name); - return Call(op, {data, weight}, Attrs(attrs), {}); -} - Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_name) { auto attrs = make_object(); attrs->tile_size = tile_size; @@ -112,50 +51,6 @@ Expr MakeConvGemmWeightTransform(Expr weight, int tile_rows, int tile_cols, std: return Call(op, {weight}, Attrs(attrs), {}); } -template -Expr MakeConvTranspose(Expr data, Expr weight, Array strides, Array padding, - Array dilation, int groups, IndexExpr channels, - Array kernel_size, std::string data_layout, - std::string kernel_layout, std::string out_layout, - Array output_padding, DataType out_dtype, std::string op_name) { - auto attrs = make_object(); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->dilation = std::move(dilation); - attrs->groups = groups; - attrs->channels = std::move(channels); - attrs->kernel_size = std::move(kernel_size); - attrs->data_layout = std::move(data_layout); - attrs->kernel_layout = std::move(kernel_layout); - attrs->out_layout = std::move(out_layout); - attrs->output_padding = std::move(output_padding); - attrs->out_dtype = std::move(out_dtype); - const Op& op = Op::Get(op_name); - return Call(op, {data, weight}, Attrs(attrs), {}); -} - -template -Expr MakeDeformableConv(Expr data, Expr offset, Expr weight, Array strides, - Array padding, Array dilation, int deformable_groups, - int groups, int channels, Array kernel_size, - std::string data_layout, std::string kernel_layout, std::string out_layout, - DataType out_dtype, std::string op_name) { - auto attrs = make_object(); - attrs->strides = strides; - attrs->padding = padding; - attrs->dilation = dilation; - attrs->deformable_groups = deformable_groups; - attrs->groups = groups; - attrs->channels = channels; - attrs->kernel_size = kernel_size; - attrs->data_layout = data_layout; - attrs->kernel_layout = kernel_layout; - attrs->out_layout = out_layout; - attrs->out_dtype = out_dtype; - const Op& op = Op::Get(op_name); - return Call(op, {data, offset, weight}, Attrs{attrs}, {}); -} - // relay.nn.conv1d TVM_REGISTER_NODE_TYPE(Conv1DAttrs); diff --git a/src/relay/op/nn/convolution_make.h b/src/relay/op/nn/convolution_make.h new file mode 100644 index 000000000000..01d6f183f79e --- /dev/null +++ b/src/relay/op/nn/convolution_make.h @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/nn/make_convolution.h + * \brief utilities for creating convolution ops + */ +#ifndef TVM_RELAY_OP_NN_CONVOLUTION_MAKE_H_ +#define TVM_RELAY_OP_NN_CONVOLUTION_MAKE_H_ + +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relay { + +template +inline Expr MakeConv(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype, + std::string op_name) { + auto attrs = make_object(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + const Op& op = Op::Get(op_name); + return Call(op, {data, weight}, Attrs(attrs), {}); +} + +template +inline Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array strides, + Array padding, Array dilation, int groups, + IndexExpr channels, Array kernel_size, + std::string data_layout, std::string kernel_layout, + std::string out_layout, DataType out_dtype, std::string op_name) { + auto attrs = make_object(); + attrs->tile_size = tile_size; + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + const Op& op = Op::Get(op_name); + return Call(op, {data, weight}, Attrs(attrs), {}); +} + +template +inline Expr MakeConvGemm(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype, + std::string op_name) { + auto attrs = make_object(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + const Op& op = Op::Get(op_name); + return Call(op, {data, weight}, Attrs(attrs), {}); +} + +template +inline Expr MakeConvTranspose(Expr data, Expr weight, Array strides, + Array padding, Array dilation, int groups, + IndexExpr channels, Array kernel_size, + std::string data_layout, std::string kernel_layout, + std::string out_layout, Array output_padding, + DataType out_dtype, std::string op_name) { + auto attrs = make_object(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->output_padding = std::move(output_padding); + attrs->out_dtype = std::move(out_dtype); + const Op& op = Op::Get(op_name); + return Call(op, {data, weight}, Attrs(attrs), {}); +} + +template +inline Expr MakeDeformableConv(Expr data, Expr offset, Expr weight, Array strides, + Array padding, Array dilation, + int deformable_groups, int groups, int channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, + DataType out_dtype, std::string op_name) { + auto attrs = make_object(); + attrs->strides = strides; + attrs->padding = padding; + attrs->dilation = dilation; + attrs->deformable_groups = deformable_groups; + attrs->groups = groups; + attrs->channels = channels; + attrs->kernel_size = kernel_size; + attrs->data_layout = data_layout; + attrs->kernel_layout = kernel_layout; + attrs->out_layout = out_layout; + attrs->out_dtype = out_dtype; + const Op& op = Op::Get(op_name); + return Call(op, {data, offset, weight}, Attrs{attrs}, {}); +} + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_OP_NN_CONVOLUTION_MAKE_H_ diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index d65fc27472c0..7013c02fde20 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -37,6 +37,7 @@ #include #include "../../transforms/infer_layout_util.h" +#include "../make_op.h" #include "../op_common.h" #include "../type_relations.h" diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index aba87e2017a0..52259c535125 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -29,6 +29,7 @@ #include +#include "../make_op.h" #include "../op_common.h" namespace tvm { diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index e54a5f32fc88..63f0ce539d82 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -21,6 +21,8 @@ * \file pooling.cc * \brief Pooling operators */ +#include "pooling.h" + #include #include #include @@ -56,34 +58,6 @@ Array > PoolInferCorrectLayout(const Attrs& attrs, return Array >{{inferred_layout}, {inferred_layout}}; } -template -Expr MakeMaxPool(Expr data, Array pool_size, Array strides, - Array padding, String layout, bool ceil_mode, String op_name) { - auto attrs = make_object(); - attrs->pool_size = std::move(pool_size); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->layout = std::move(layout); - attrs->ceil_mode = ceil_mode; - static const Op& op = Op::Get(op_name); - return Call(op, {data}, Attrs(attrs), {}); -} - -template -Expr MakeAvgPool(Expr data, Array pool_size, Array strides, - Array padding, String layout, bool ceil_mode, bool count_include_pad, - String op_name) { - auto attrs = make_object(); - attrs->pool_size = std::move(pool_size); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->layout = std::move(layout); - attrs->ceil_mode = ceil_mode; - attrs->count_include_pad = count_include_pad; - static const Op& op = Op::Get(op_name); - return Call(op, {data}, Attrs(attrs), {}); -} - template bool Pool2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { diff --git a/src/relay/op/nn/pooling.h b/src/relay/op/nn/pooling.h new file mode 100644 index 000000000000..a803698b93eb --- /dev/null +++ b/src/relay/op/nn/pooling.h @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/nn/convolution.h + * \brief Properties def of convlution operator for sharing. + */ +#ifndef TVM_RELAY_OP_NN_POOLING_H_ +#define TVM_RELAY_OP_NN_POOLING_H_ + +#include +#include + +#include + +namespace tvm { +namespace relay { + +template +inline Expr MakeMaxPool(Expr data, Array pool_size, Array strides, + Array padding, String layout, bool ceil_mode, String op_name) { + auto attrs = make_object(); + attrs->pool_size = std::move(pool_size); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->layout = std::move(layout); + attrs->ceil_mode = ceil_mode; + static const Op& op = Op::Get(op_name); + return Call(op, {data}, Attrs(attrs), {}); +} + +template +inline Expr MakeAvgPool(Expr data, Array pool_size, Array strides, + Array padding, String layout, bool ceil_mode, + bool count_include_pad, String op_name) { + auto attrs = make_object(); + attrs->pool_size = std::move(pool_size); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->layout = std::move(layout); + attrs->ceil_mode = ceil_mode; + attrs->count_include_pad = count_include_pad; + static const Op& op = Op::Get(op_name); + return Call(op, {data}, Attrs(attrs), {}); +} + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_OP_NN_POOLING_H_ diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index d526cef5bf62..733d6e9448fd 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -30,6 +30,7 @@ #include #include +#include "../make_op.h" #include "../op_common.h" #include "../type_relations.h" @@ -293,15 +294,19 @@ bool ReduceRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } +Expr MakeReduce(Expr data, Array axis, bool keepdims, bool exclude, String op_name) { + std::cout << "making " << op_name << std::endl; + auto attrs = make_object(); + attrs->axis = std::move(axis); + attrs->keepdims = keepdims; + attrs->exclude = exclude; + return Call(Op::Get(op_name), {data}, Attrs(attrs), {}); +} + #define RELAY_REGISTER_REDUCE_OP(OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed([](Expr data, Array axis, bool keepdims, bool exclude) { \ - auto attrs = make_object(); \ - attrs->axis = std::move(axis); \ - attrs->keepdims = keepdims; \ - attrs->exclude = exclude; \ - static const Op& op = Op::Get(OpName); \ - return Call(op, {data}, Attrs(attrs), {}); \ + return MakeReduce(data, axis, keepdims, exclude, OpName); \ }); \ RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.") diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index b1c2d8b23373..9d5f248cb229 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -40,6 +40,7 @@ #include "../../transforms/infer_layout_util.h" #include "../../transforms/pattern_util.h" +#include "../make_op.h" #include "../op_common.h" namespace tvm { diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index c68dfba784c7..4e5677a1af6d 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -35,11 +35,11 @@ #include #include +#include "../make_op.h" + namespace tvm { namespace relay { -Expr MakeReshape(Expr data, Array newshape); - template bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 99e6c026f8d1..958b8b535873 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -27,6 +27,7 @@ #include #include +#include "../make_op.h" #include "../op_common.h" #include "../type_relations.h" @@ -266,13 +267,15 @@ RELAY_REGISTER_UNARY_OP("copy") // relay.clip TVM_REGISTER_NODE_TYPE(ClipAttrs); -TVM_REGISTER_GLOBAL("relay.op._make.clip").set_body_typed([](Expr a, double a_min, double a_max) { +Expr MakeClip(Expr a, double a_min, double a_max) { auto attrs = make_object(); attrs->a_min = a_min; attrs->a_max = a_max; static const Op& op = Op::Get("clip"); return Call(op, {a}, Attrs(attrs), {}); -}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.clip").set_body_typed(MakeClip); RELAY_REGISTER_OP("clip") .describe(R"code(Clip tensor values. diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index dced5020ca0b..359e1d335bfa 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -33,48 +33,32 @@ namespace relay { class DynamicToStaticMutator : public MixedModeMutator { public: - DynamicToStaticMutator() - : dyn_reshape_op_(Op::Get("dyn.reshape")), - dyn_tile_op_(Op::Get("dyn.tile")), - dyn_topk_op_(Op::Get("dyn.topk")) {} + DynamicToStaticMutator() {} private: Expr Rewrite_(const CallNode* pre, const Expr& post) override { const CallNode* call_node = post.as(); - if (call_node->op == dyn_reshape_op_) { + if (call_node->op == Op::Get("dyn.reshape")) { if (const ConstantNode* shape = call_node->args[1].as()) { - auto attrs = make_object(); CHECK_EQ(shape->data->ndim, 1); - attrs->newshape = ToVector(shape->data); - attrs->reverse = false; - static const Op& reshape = Op::Get("reshape"); - return Call(reshape, {call_node->args[0]}, Attrs(attrs), {}); + return MakeReshape(call_node->args[0], ToVector(shape->data)); } - } else if (call_node->op == dyn_tile_op_) { + } else if (call_node->op == Op::Get("dyn.tile")) { if (const ConstantNode* reps = call_node->args[1].as()) { - auto attrs = make_object(); CHECK_EQ(reps->data->ndim, 1); - attrs->reps = ToVector(reps->data); - static const Op& op = Op::Get("tile"); - return Call(op, {call_node->args[0]}, Attrs(attrs), {}); + return MakeTile(call_node->args[0], ToVector(reps->data)); } - } else if (call_node->op == dyn_topk_op_) { + } else if (call_node->op == Op::Get("dyn.topk")) { if (const ConstantNode* k = call_node->args[1].as()) { const TopKAttrs* param = call_node->attrs.as(); CHECK(param); - auto attrs = make_object(); - attrs->k = Integer(ToScalar(k->data, 0)); - std::cout << attrs->k << std::endl; - attrs->axis = param->axis; - attrs->ret_type = param->ret_type; - attrs->is_ascend = param->is_ascend; - attrs->dtype = param->dtype; - static const Op& op = Op::Get("topk"); - return Call(op, {call_node->args[0]}, Attrs(attrs), {}); + return MakeTopK(call_node->args[0], static_cast(ToScalar(k->data, 0)), param->axis, + param->ret_type, param->is_ascend, param->dtype); } } return post; } + Expr DispatchVisitExpr(const Expr& expr) override { auto post = MixedModeMutator::DispatchVisitExpr(expr); if (auto op = post.as()) { @@ -82,10 +66,6 @@ class DynamicToStaticMutator : public MixedModeMutator { } return post; } - - const Op& dyn_reshape_op_; - const Op& dyn_tile_op_; - const Op& dyn_topk_op_; }; Expr DynamicToStatic(Function f, IRModule m) { diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 78068e88a510..62a58d2b7ffb 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -42,6 +42,8 @@ #include #include +#include "../op/make_op.h" + namespace tvm { namespace relay { @@ -448,12 +450,7 @@ T GetScalarFromConstant(Expr expr) { return static_cast(n->data->data)[0]; } -inline Expr Cast(Expr x, DataType dtype) { - static const Op& op = Op::Get("cast"); - auto attrs = make_object(); - attrs->dtype = dtype; - return Call(op, {x}, Attrs(attrs), {}); -} +inline Expr Cast(Expr x, DataType dtype) { return MakeCast(x, dtype); } inline Expr Negative(Expr x) { static const Op& op = Op::Get("negative"); @@ -475,13 +472,7 @@ inline Expr Round(Expr x) { return Call(op, {x}, Attrs(), {}); } -inline Expr Clip(Expr x, double a_min, double a_max) { - static const Op& op = Op::Get("clip"); - auto attrs = make_object(); - attrs->a_min = a_min; - attrs->a_max = a_max; - return Call(op, {x}, Attrs(attrs), {}); -} +inline Expr Clip(Expr x, double a_min, double a_max) { return MakeClip(x, a_min, a_max); } inline Expr Add(Expr lhs, Expr rhs) { static const Op& op = Op::Get("add"); @@ -513,8 +504,6 @@ inline Expr ZerosLike(Expr e) { return Call(op, {e}); } -Expr MakeZeros(Expr shape, DataType dtype); - inline Expr Zeros(Array shape, DataType dtype) { return MakeZeros(CheckConstantShape(shape), dtype); } @@ -524,8 +513,6 @@ inline Expr OnesLike(Expr e) { return Call(op, {e}); } -Expr MakeOnes(Expr shape, DataType dtype); - inline Expr Ones(Array shape, DataType dtype) { return MakeOnes(CheckConstantShape(shape), dtype); } @@ -561,21 +548,11 @@ inline Expr Copy(Expr data) { } inline Expr Mean(Expr data, Array axis, bool keepdims, bool exclude) { - auto attrs = make_object(); - attrs->axis = std::move(axis); - attrs->keepdims = keepdims; - attrs->exclude = exclude; - static const Op& op = Op::Get("mean"); - return Call(op, {data}, Attrs(attrs), {}); + return MakeReduce(data, axis, keepdims, exclude, "mean"); } inline Expr Variance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude) { - auto attrs = make_object(); - attrs->axis = std::move(axis); - attrs->keepdims = keepdims; - attrs->exclude = exclude; - static const Op& op = Op::Get("variance"); - return Call(op, {data, mean}, Attrs(attrs), {}); + return MakeVariance(data, mean, axis, keepdims, exclude); } static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { @@ -588,8 +565,6 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { return Call(op, {lhs, rhs}, Attrs(), {}); } -Expr MakeFull(Expr fill_value, Expr shape, DataType dtype); - static inline Expr Full(Expr fill_value, Array shape, DataType dtype) { return MakeFull(fill_value, CheckConstantShape(shape), dtype); } @@ -598,40 +573,19 @@ static inline Expr Conv2D(Expr data, Expr weight, Array strides, Array padding, Array dilation, int groups, IndexExpr channels, Array kernel_size, std::string data_layout, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_object(); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->dilation = std::move(dilation); - attrs->groups = groups; - attrs->channels = std::move(channels); - attrs->kernel_size = std::move(kernel_size); - attrs->data_layout = std::move(data_layout); - attrs->kernel_layout = std::move(kernel_layout); - attrs->out_layout = std::move(out_layout); - attrs->out_dtype = std::move(out_dtype); - static const Op& op = Op::Get("nn.conv2d"); - return Call(op, {data, weight}, Attrs(attrs), {}); + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.conv2d"); } static inline Expr Dense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { - auto attrs = make_object(); - attrs->units = units; - attrs->out_dtype = out_dtype; - static const Op& op = Op::Get("nn.dense"); - return Call(op, {data, weight}, Attrs(attrs), {}); + return MakeDense(data, weight, units, out_dtype); } static inline Expr Sum(Expr data, Array axis, bool keepdims, bool exclude) { - auto attrs = make_object(); - attrs->axis = std::move(axis); - attrs->keepdims = keepdims; - attrs->exclude = exclude; - static const Op& op = Op::Get("sum"); - return Call(op, {data}, Attrs(attrs), {}); + return MakeReduce(data, axis, keepdims, exclude, "sum"); } -Expr MakeReshape(Expr data, Array newshape); - static inline Expr Reshape(Expr data, Array newshape) { return MakeReshape(data, newshape); } @@ -639,56 +593,21 @@ static inline Expr Reshape(Expr data, Array newshape) { static inline Expr AvgPool2D(Expr data, Array pool_size, Array strides, Array padding, std::string layout, bool ceil_mode, bool count_include_pad) { - auto attrs = make_object(); - attrs->pool_size = std::move(pool_size); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->layout = std::move(layout); - attrs->ceil_mode = ceil_mode; - attrs->count_include_pad = count_include_pad; - static const Op& op = Op::Get("nn.avg_pool2d"); - return Call(op, {data}, Attrs(attrs), {}); + return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, + count_include_pad, "nn.avg_pool2d"); } static inline Expr Pad(Expr data, Array> pad_width, double pad_value, std::string pad_mode) { - auto attrs = make_object(); - attrs->pad_value = pad_value; - attrs->pad_width = std::move(pad_width); - attrs->pad_mode = std::move(pad_mode); - static const Op& op = Op::Get("nn.pad"); - return Call(op, {data}, Attrs(attrs), {}); -} - -static inline Expr Tile(Expr data, Array reps) { - auto attrs = make_object(); - attrs->reps = reps; - static const Op& op = Op::Get("tile"); - return Call(op, {data}, Attrs(attrs), {}); + return MakePad(data, pad_width, pad_value, pad_mode); } -Expr MakeBroadCastTo(Expr data, Expr shape); +static inline Expr Tile(Expr data, Array reps) { return MakeTile(data, reps); } static inline Expr BroadCastTo(Expr data, Array shape) { return MakeBroadCastTo(data, CheckConstantShape(shape)); } -Expr MakeConcatenate(Expr data, int axis); - -Expr MakeRepeat(Expr data, int repeats, int axis); - -Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, String slice_mode); - -Expr MakeStack(Expr data, int axis); - -Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis); - -Expr MakeSqueeze(Expr data, Array axis); - -Expr MakeExpandDims(Expr data, int axis, int num_newaxis); - -Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout); - Expr StopFusion(Expr data); Expr CastHint(Expr data, DataType dtype);