From c7120cb5182e76018fbd3c4dc6a66adb45ff633c Mon Sep 17 00:00:00 2001 From: Jon Soifer Date: Mon, 26 Aug 2019 15:02:13 -0700 Subject: [PATCH 01/21] Refactor to create abstract ParallelOpCombiner --- src/relay/pass/combine_parallel_conv2d.cc | 167 +++++----------------- src/relay/pass/combine_parallel_op.cc | 136 ++++++++++++++++++ src/relay/pass/combine_parallel_op.h | 115 +++++++++++++++ 3 files changed, 285 insertions(+), 133 deletions(-) create mode 100644 src/relay/pass/combine_parallel_op.cc create mode 100644 src/relay/pass/combine_parallel_op.h diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index d72705c8ce47..ee479aaafba0 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * * \file combine_parallel_conv2d.cc * \brief Combine parallel 2d convolutions into a single convolution. @@ -43,66 +43,22 @@ #include #include "./expr_subst.h" #include "./pattern_util.h" - +#include "./combine_parallel_op.h" namespace tvm { namespace relay { -using Branch = std::vector; -using Group = std::vector; - -/* - Find parallel branches starting with conv2d as shown below and then group branches by kernel - shape and attributes of conv2d. Conv2d can be followed by zero or more elemwise or broadcast ops. - Intermediate nodes have exactly one successor. It is possible that branches meet at a point, - which should be handled in ParallelConv2DCombiner. - - data - / \ - conv2d conv2d - | | - op op - | | -*/ -class BranchGroupFinder : private ExprVisitor { +class ParallelConv2DCombiner : public ParallelOpCombiner { public: - std::vector Find(const Expr& expr) { - static const Op& conv2d = Op::Get("nn.conv2d"); - - this->VisitExpr(expr); - - std::vector groups; - for (const auto& root : conv_roots_) { - const auto& children = children_map_.at(root); - size_t ngroups = groups.size(); - for (const CallNode* child : children) { - if (!child->op.same_as(conv2d)) continue; - - auto&& branch = CreateBranch(child); - // add the branch to a group, or create a new group - auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) { - CHECK(!group.empty() && !group[0].empty()); - return IsCompatibleConv2D(child, group[0][0]); - }); - if (it != groups.end()) { - it->push_back(branch); - } else { - groups.emplace_back(); - // each group has at least one branch - groups.back().push_back(branch); - } - } - } - return groups; + ParallelConv2DCombiner(uint64_t min_num_branches) : ParallelOpCombiner("nn.conv2d", min_num_branches) { } - private: - std::unordered_set conv_roots_; - std::unordered_map, NodeHash, NodeEqual> children_map_; + protected: + virtual bool IsSupportedOp(const CallNode* n) { + return n->attrs.as()->groups == 1; + } - // Two 2d convolutions can be combined if they have the same attributes or - // only have different output channels. - bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) { + virtual bool AreCompatibleOps(const CallNode* a, const CallNode* b) { AttrsEqual eq; static const Layout kOIHW("OIHW"); const auto* attrs_a = a->attrs.as(); @@ -125,59 +81,34 @@ class BranchGroupFinder : private ExprVisitor { eq(shape_a[3], shape_b[3]); } - // Create a branch starting from conv2d. - Branch CreateBranch(const CallNode* conv) { - static auto fpattern = Op::GetAttr("TOpPattern"); - // each branch has at least one element, the first element is always conv2d - Branch branch{conv}; - auto it = children_map_.find(GetRef(branch.back())); - while (it != children_map_.end() && it->second.size() == 1) { - const CallNode* call = it->second[0]; - auto pattern = fpattern[Downcast(call->op)]; - if (pattern <= kBroadcast) { - branch.push_back(call); - it = children_map_.find(GetRef(branch.back())); - } else { - break; - } - } - return branch; - } - - void VisitExpr_(const CallNode* n) final { - static const Op& conv2d = Op::Get("nn.conv2d"); - ExprVisitor::VisitExpr_(n); - if (n->op.same_as(conv2d) && n->attrs.as()->groups == 1) { - conv_roots_.insert(n->args[0]); - children_map_[n->args[0]].push_back(n); - } else { - for (size_t i = 0; i < n->args.size(); i++) { - children_map_[n->args[i]].push_back(n); - } - } - } -}; - -class ParallelConv2DCombiner { - public: - explicit ParallelConv2DCombiner(uint64_t min_num_branches) : min_num_branches_(min_num_branches) { - } - - Expr Combine(const Expr& expr) { - auto groups = BranchGroupFinder().Find(expr); - for (const Group& group : groups) { - if (group.size() < min_num_branches_) { - continue; + virtual void CombineBranches(const Group& branches, ExprSubstMap& subst_map) { + Call combined = MakeCombinedConv2D(branches); + auto conv_param = combined->attrs.as(); + const std::string& layout = + conv_param->out_layout == "" ? conv_param->data_layout : conv_param->out_layout; + size_t channel_pos = layout.find('C'); + CHECK_NE(channel_pos, std::string::npos); + auto it = std::min_element(branches.begin(), branches.end(), + [](const Branch& branch_a, + const Branch& branch_b) { + return branch_a.size() < branch_b.size(); + }); + size_t depth = it->size(); + size_t i; + // starting from 1 to skip the conv2d + for (i = 1; i < depth; i++) { + size_t parent_index; + for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) { + if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break; } - CombineBranches(group); + CHECK_NE(parent_index, branches[0][i]->args.size()); + if (!CheckLevel(branches, i, channel_pos, parent_index)) break; + combined = MakeCombinedCall(combined, branches, i, channel_pos, parent_index); } - return ExprSubst(expr, std::move(subst_map_)); + UpdateGroupOutput(combined, branches, i - 1, channel_pos, subst_map); } private: - std::unordered_map subst_map_; - uint64_t min_num_branches_; - std::tuple TransformWeight(const Group& branches) { int64_t num_filters = 0; // number of filters of the transformed weight Array weights; @@ -300,7 +231,7 @@ class ParallelConv2DCombiner { // Replace output of each branch with slices of the combined output void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, - size_t channel_pos) { + size_t channel_pos, ExprSubstMap& subst_map) { int64_t index = 0; for (const auto& branch : branches) { const CallNode* conv2d = branch[0]; @@ -315,38 +246,8 @@ class ParallelConv2DCombiner { index += channels; end.push_back(index); auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array{}); - subst_map_[GetRef(branch[depth])] = slice; - } - } - - // Combine branches in a group. Conv2d in different branches in the same group are safe to - // combine. Subsequent ops may or may not be combined. We start from conv2d and try to - // combine ops from all branches in the same depth. - void CombineBranches(const Group& branches) { - Call combined = MakeCombinedConv2D(branches); - auto conv_param = combined->attrs.as(); - const std::string& layout = - conv_param->out_layout == "" ? conv_param->data_layout : conv_param->out_layout; - size_t channel_pos = layout.find('C'); - CHECK_NE(channel_pos, std::string::npos); - auto it = std::min_element(branches.begin(), branches.end(), - [](const Branch& branch_a, - const Branch& branch_b) { - return branch_a.size() < branch_b.size(); - }); - size_t depth = it->size(); - size_t i; - // starting from 1 to skip the conv2d - for (i = 1; i < depth; i++) { - size_t parent_index; - for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) { - if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break; - } - CHECK_NE(parent_index, branches[0][i]->args.size()); - if (!CheckLevel(branches, i, channel_pos, parent_index)) break; - combined = MakeCombinedCall(combined, branches, i, channel_pos, parent_index); + subst_map[GetRef(branch[depth])] = slice; } - UpdateGroupOutput(combined, branches, i - 1, channel_pos); } }; diff --git a/src/relay/pass/combine_parallel_op.cc b/src/relay/pass/combine_parallel_op.cc new file mode 100644 index 000000000000..08be9eb9195c --- /dev/null +++ b/src/relay/pass/combine_parallel_op.cc @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file combine_parallel_op.cc + * \brief Abstract class to combine parallel ops and their successive element-wise ops. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./expr_subst.h" +#include "./pattern_util.h" +#include "./combine_parallel_op.h" + + +namespace tvm { +namespace relay { + +BranchGroupFinder::BranchGroupFinder(const std::string& op_name, + FIsSupportedOp fis_supported_op, + FAreCompatibleOps fare_compatible_ops) + : op_name_(op_name), + fis_supported_op_(fis_supported_op), + fare_compatible_ops_(fare_compatible_ops) { +} + +std::vector BranchGroupFinder::Find(const Expr& expr) { + static const Op& op = Op::Get(op_name_); + + this->VisitExpr(expr); + + std::vector groups; + for (const auto& root : op_roots_) { + const auto& children = children_map_.at(root); + size_t ngroups = groups.size(); + for (const CallNode* child : children) { + if (!child->op.same_as(op)) continue; + + auto&& branch = CreateBranch(child); + // add the branch to a group, or create a new group + auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) { + CHECK(!group.empty() && !group[0].empty()); + return fare_compatible_ops_(child, group[0][0]); + }); + if (it != groups.end()) { + it->push_back(branch); + } else { + groups.emplace_back(); + // each group has at least one branch + groups.back().push_back(branch); + } + } + } + return groups; +} + +// Create a branch starting from op. +Branch BranchGroupFinder::CreateBranch(const CallNode* op) { + static auto fpattern = Op::GetAttr("TOpPattern"); + // each branch has at least one element, the first element is always op + Branch branch{op}; + auto it = children_map_.find(GetRef(branch.back())); + while (it != children_map_.end() && it->second.size() == 1) { + const CallNode* call = it->second[0]; + auto pattern = fpattern[Downcast(call->op)]; + if (pattern <= kBroadcast) { + branch.push_back(call); + it = children_map_.find(GetRef(branch.back())); + } else { + break; + } + } + return branch; +} + +void BranchGroupFinder::VisitExpr_(const CallNode* n) { + static const Op& op = Op::Get(op_name_); + ExprVisitor::VisitExpr_(n); + if (n->op.same_as(op) && fis_supported_op_(n)) { + op_roots_.insert(n->args[0]); + children_map_[n->args[0]].push_back(n); + } else { + for (size_t i = 0; i < n->args.size(); i++) { + children_map_[n->args[i]].push_back(n); + } + } +} + +ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches) + : op_name_(op_name), + min_num_branches_(min_num_branches) { +} + +Expr ParallelOpCombiner::Combine(const Expr& expr) { + auto groups = BranchGroupFinder(op_name_, + [&](const CallNode* n) { + return IsSupportedOp(n); + }, + [&](const CallNode* a, const CallNode* b) { + return AreCompatibleOps(a, b); + }).Find(expr); + for (const Group& group : groups) { + if (group.size() < min_num_branches_) { + continue; + } + CombineBranches(group, subst_map_); + } + return ExprSubst(expr, std::move(subst_map_)); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/combine_parallel_op.h b/src/relay/pass/combine_parallel_op.h new file mode 100644 index 000000000000..bcb3cd37a6dd --- /dev/null +++ b/src/relay/pass/combine_parallel_op.h @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file combine_parallel_op.h + * \brief Abstract class to combine parallel ops and their successive element-wise ops. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./expr_subst.h" +#include "./pattern_util.h" + + +namespace tvm { +namespace relay { + +using Branch = std::vector; +using Group = std::vector; +using FIsSupportedOp = std::function; +using FAreCompatibleOps = std::function; +using ExprSubstMap = std::unordered_map; + +/* + Class to find parallel branches starting with op as shown below and then + group branches by kernel shape and attributes of op. + Op can be followed by zero or more elemwise or broadcast ops. + Intermediate nodes have exactly one successor. It is possible that branches meet at a point, + which should be handled in ParallelOpCombiner. + + data + / \ + op op + | | + elem-wise elem-wise + | | +*/ +class BranchGroupFinder : private ExprVisitor { + public: + BranchGroupFinder(const std::string& op_name, + FIsSupportedOp fis_supported_op, + FAreCompatibleOps fare_compatible_ops); + + std::vector Find(const Expr& expr); + + private: + std::string op_name_; + FIsSupportedOp fis_supported_op_; + FAreCompatibleOps fare_compatible_ops_; + std::unordered_set op_roots_; + std::unordered_map, NodeHash, NodeEqual> children_map_; + + // Create a branch starting from op. + Branch CreateBranch(const CallNode* op); + + void VisitExpr_(const CallNode* n) final; +}; + +/* + Abstract class to find and combine parallel ops and the element-wise ops that follow. +*/ +class ParallelOpCombiner { + public: + explicit ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches); + + Expr Combine(const Expr& expr); + + protected: + // Returns true if the op represented by CallNode n is supported to be the + // root of a branch to be combined. Otherwise, returns false. + virtual bool IsSupportedOp(const CallNode* n) = 0; + + // Returns true if ops represented by CallNodes a and b can be combined. + // Otherwise, returns false. + virtual bool AreCompatibleOps(const CallNode* a, const CallNode* b) = 0; + + // Combine branches in a group. Ops in different branches in the same group are safe to + // combine. Subsequent ops may or may not be combined. Start from op and try to + // combine ops from all branches in the same depth. + // Ops should be updated by updating subst_map, + // which maps original Expr to Expr to substitute it with. + virtual void CombineBranches(const Group& branches, ExprSubstMap& subst_map) = 0; + + private: + std::string op_name_; + uint64_t min_num_branches_; + ExprSubstMap subst_map_; +}; + +} // namespace relay +} // namespace tvm From 536c09f094c52d704726fa5d09bf39a9703f5d8b Mon Sep 17 00:00:00 2001 From: Jon Soifer Date: Mon, 26 Aug 2019 17:57:30 -0700 Subject: [PATCH 02/21] First draft of CombineParallelDense --- src/relay/pass/combine_parallel_dense.cc | 243 +++++++++++++++++++++++ src/relay/pass/pattern_util.h | 4 + 2 files changed, 247 insertions(+) create mode 100644 src/relay/pass/combine_parallel_dense.cc diff --git a/src/relay/pass/combine_parallel_dense.cc b/src/relay/pass/combine_parallel_dense.cc new file mode 100644 index 000000000000..25c4e7d200c4 --- /dev/null +++ b/src/relay/pass/combine_parallel_dense.cc @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file combine_parallel_dense.cc + * \brief Combine parallel dense ops into a single dense. + * + * This pass replaces dense ops that share the same input node, same shape, + * and don't have "units" defined with a single batch matrix multiplication. + * The inputs of the new batch_matmul is the stack of the original inputs. + * Elemwise and broadcast ops following dense are also combined if possible. + * + * This prevents launching multiple kernels in networks with multiple + * dense branches, such as BERT. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./expr_subst.h" +#include "./pattern_util.h" +#include "./combine_parallel_op.h" + +namespace tvm { +namespace relay { + +class ParallelDenseCombiner : public ParallelOpCombiner { + public: + ParallelDenseCombiner(uint64_t min_num_branches) : ParallelOpCombiner("nn.dense", min_num_branches) { + } + + protected: + virtual bool IsSupportedOp(const CallNode* n) { + const auto* attrs = n->attrs.as(); + return !attrs->units.defined(); + } + + virtual bool AreCompatibleOps(const CallNode* a, const CallNode* b) { + AttrsEqual eq; + const auto* attrs_a = a->attrs.as(); + const auto* attrs_b = b->attrs.as(); + CHECK(attrs_a); + CHECK(attrs_b); + const auto* weight_a = a->args[1]->type_as(); + const auto* weight_b = b->args[1]->type_as(); + + return eq(attrs_a->out_dtype, attrs_b->out_dtype) && + eq(weight_a->shape[0], weight_b->shape[0]) && + eq(weight_a->shape[1], weight_b->shape[1]) && + eq(attrs_a->units.defined(), attrs_b->units.defined()); + } + + virtual void CombineBranches(const Group& branches, ExprSubstMap& subst_map) { + Call combined = MakeCombinedDense(branches); + auto it = std::min_element(branches.begin(), branches.end(), + [](const Branch& branch_a, + const Branch& branch_b) { + return branch_a.size() < branch_b.size(); + }); + size_t depth = it->size(); + size_t i; + // starting from 1 to skip the dense + for (i = 1; i < depth; i++) { + size_t parent_index; + for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) { + if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break; + } + CHECK_NE(parent_index, branches[0][i]->args.size()); + if (!CheckLevel(branches, i, parent_index)) break; + combined = MakeCombinedCall(combined, branches, i, parent_index); + } + UpdateGroupOutput(combined, branches, i - 1, subst_map); + } + + private: + std::tuple TransformWeight(const Group& branches) { + int64_t num_filters = 0; // number of filters of the transformed weight + Array weights; + for (const auto& branch : branches) { + auto conv2d = branch[0]; + weights.push_back(conv2d->args[1]); + auto channels = GetConv2DSuperChannelsDim(conv2d); + num_filters += channels; + } + auto index = branches[0][0]->attrs.as()->kernel_layout.find('O'); + CHECK_NE(index, std::string::npos); + return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), + MakeConstScalar(Int(32), num_filters)); + } + + // Combine dense into batch matmul. + Call MakeCombinedDense(const Group& branches) { + static const Op& batch_matmul = Op::Get("nn.batch_matmul"); + Array datas; + Array weights; + for (const auto& branch : branches) { + auto dense = branch[0]; + auto data = dense->args[0]; + auto weight = dense->args[1]; + datas.push_back(data); + weights.push_back(weight); + } + + Expr new_data = MakeStack(TupleNode::make(datas)); + Expr new_weight = MakeStack(TupleNode::make(weights)); + return CallNode::make(batch_matmul, {new_data, new_weight}, Attrs(), {}); + } + + bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { + AttrsEqual eq; + auto ta = a->args[index]->type_as(); + auto tb = b->args[index]->type_as(); + auto toutput_a = a->type_as(); + auto toutput_b = b->type_as(); + + if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) + return false; + + for (size_t i = 0; i < ta->shape.size(); i++) { + if (!eq(ta->shape[i], tb->shape[i])) + return false; + } + return true; + } + + // Check if ops in depth-th level can be combined + bool CheckLevel(const Group& branches, size_t depth, size_t parent_index) { + const CallNode* call = branches[0][depth]; + AttrsEqual attrs_equal; + // check if all branches in current depth can be combined + for (auto it = branches.begin() + 1; it != branches.end(); it++) { + const Branch& branch = *it; + if (!branch[depth]->op.same_as(call->op) || + !attrs_equal(branch[depth]->attrs, call->attrs) || + branch[depth]->args.size() != call->args.size()) { + return false; + } + + if (branch[depth]->args[parent_index].get() != branch[depth - 1]) + return false; + + // Check args + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) continue; + + if (!IsArgCompatible(call, branch[depth], i) || + !attrs_equal(call->attrs, branch[depth]->attrs)) { + return false; + } + } + } + return true; + } + + // Combine args and make the combined CallNode + Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t parent_index) { + Array new_args; + const CallNode* call = branches[0][depth]; + + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) { + new_args.push_back(data); + continue; + } + + Array tuple; + for (const auto& branch : branches) { + // if the shape of the arg is 1D, expand it to (1,j) so it can be properly broadcasted. + Expr arg = branch[depth]->args[i]; + const TensorTypeNode* arg_tensor = arg->type_as(); + if (arg_tensor->shape.size() == 1) { + Expr expanded_arg = ExpandBiasToMatchAxis(arg, 2, {0}); + tuple.push_back(expanded_arg); + } else { + tuple.push_back(arg); + } + } + + auto stack = MakeStack(TupleNode::make(tuple)); + new_args.push_back(std::move(stack)); + } + + return CallNode::make(call->op, new_args, call->attrs, {}); + } + + // Replace output of each branch with slices of the combined output + void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap& subst_map) { + int index = 0; + auto split = MakeSplit(data, branches.size(), 0); + for (const auto& branch : branches) { + const CallNode* dense = branch[0]; + auto split_data = TupleGetItemNode::make(split, index); + subst_map[GetRef(branch[depth])] = split_data; + } + } +}; + +/*! \brief Combine parallel conv2d if number of branches >= min_num_branches */ +Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) { + return ParallelDenseCombiner(min_num_branches).Combine(expr); +} + +namespace transform { + +Pass CombineParallelDense(uint64_t min_num_branches) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(CombineParallelDense(f, min_num_branches)); + }; + return CreateFunctionPass(pass_func, 4, "CombineParallelDense", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.CombineParallelDense") +.set_body_typed(CombineParallelDense); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 18e5df3e04df..27a618b0ac1a 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -419,6 +419,10 @@ Expr MakeConcatenate(Expr data, int axis); Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); +Expr MakeStack(Expr data); + +Expr MakeSplit(Expr data, int indices_or_sections, int axis); + Expr StopFusion(Expr data); Expr CastHint(Expr data, DataType dtype); From 8f1502e39951211ca888e00913ed37a2c57fa1bb Mon Sep 17 00:00:00 2001 From: Jon Soifer Date: Tue, 27 Aug 2019 11:01:21 -0700 Subject: [PATCH 03/21] Begin to work on tests --- docs/api/python/relay/transform.rst | 2 + include/tvm/relay/transform.h | 11 + python/tvm/relay/transform.py | 18 ++ src/relay/backend/build_module.cc | 1 + src/relay/pass/combine_parallel_dense.cc | 17 +- .../relay/test_pass_combine_parallel_dense.py | 200 ++++++++++++++++++ 6 files changed, 233 insertions(+), 16 deletions(-) create mode 100644 tests/python/relay/test_pass_combine_parallel_dense.py diff --git a/docs/api/python/relay/transform.rst b/docs/api/python/relay/transform.rst index 3c0a6dcf22f6..346152b9c769 100644 --- a/docs/api/python/relay/transform.rst +++ b/docs/api/python/relay/transform.rst @@ -46,6 +46,8 @@ tvm.relay.transform .. autofunction:: tvm.relay.transform.CombineParallelConv2D +.. autofunction:: tvm.relay.transform.CombineParallelDense + .. autofunction:: tvm.relay.transform.AlterOpLayout .. autofunction:: tvm.relay.transform.Legalize diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 4bd59302f0d8..14f25cf90726 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -482,6 +482,17 @@ TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr); */ TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3); +/*! + * \brief Combine parallel dense ops into a single batch_matmul if the + * number of branches of this dense operator is not less than + * `min_num_branch`. + * + * \param min_num_branches The minimun number of branches. + * + * \return The pass. + */ +TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3); + /*! * \brief Backward fold axis scaling into weights of conv/dense operators. * diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index ccdf00ed64e3..4562d3488afa 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -138,6 +138,7 @@ def build_config(opt_level=2, "CanonicalizeCast": 3, "EliminateCommonSubexpr": 3, "CombineParallelConv2D": 4, + "CombineParallelDense": 4 } fallback_device : int, str, or tvm.TVMContext, optional @@ -400,6 +401,23 @@ def CombineParallelConv2D(min_num_branches=3): return _transform.CombineParallelConv2D(min_num_branches) +def CombineParallelDense(min_num_branches=3): + """Combine multiple dense operators into one. + + Parameters + ---------- + min_num_branches : int + The minimum number of required parallel branches for performing this + optimization. + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that combines parallel dense operators. + """ + return _transform.CombineParallelDense(min_num_branches) + + def AlterOpLayout(): """Alternate the layouts of operators or replace primitive operators with other expressions. diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index f757dad520ef..278ef43dd177 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -299,6 +299,7 @@ class RelayBuildModule : public runtime::ModuleNode { }); pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); pass_seqs.push_back(transform::CombineParallelConv2D(3)); + pass_seqs.push_back(transform::CombineParallelDense(3)); pass_seqs.push_back(transform::FoldConstant()); pass_seqs.push_back(transform::FoldScaleAxis()); pass_seqs.push_back(transform::CanonicalizeCast()); diff --git a/src/relay/pass/combine_parallel_dense.cc b/src/relay/pass/combine_parallel_dense.cc index 25c4e7d200c4..25fe340fa4f3 100644 --- a/src/relay/pass/combine_parallel_dense.cc +++ b/src/relay/pass/combine_parallel_dense.cc @@ -96,21 +96,6 @@ class ParallelDenseCombiner : public ParallelOpCombiner { } private: - std::tuple TransformWeight(const Group& branches) { - int64_t num_filters = 0; // number of filters of the transformed weight - Array weights; - for (const auto& branch : branches) { - auto conv2d = branch[0]; - weights.push_back(conv2d->args[1]); - auto channels = GetConv2DSuperChannelsDim(conv2d); - num_filters += channels; - } - auto index = branches[0][0]->attrs.as()->kernel_layout.find('O'); - CHECK_NE(index, std::string::npos); - return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), - MakeConstScalar(Int(32), num_filters)); - } - // Combine dense into batch matmul. Call MakeCombinedDense(const Group& branches) { static const Op& batch_matmul = Op::Get("nn.batch_matmul"); @@ -218,7 +203,7 @@ class ParallelDenseCombiner : public ParallelOpCombiner { } }; -/*! \brief Combine parallel conv2d if number of branches >= min_num_branches */ +/*! \brief Combine parallel dense if number of branches >= min_num_branches */ Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) { return ParallelDenseCombiner(min_num_branches).Combine(expr); } diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py new file mode 100644 index 000000000000..227606a5bdee --- /dev/null +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -0,0 +1,200 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from tvm import relay +from tvm.relay import transform + + +def run_combine_parallel(expr, min_num_branches=3): + mod = relay.Module.from_expr(expr) + mod = transform.CombineParallelDense(min_num_branches)(mod) + return mod["main"] + +def run_opt_pass(expr, opt_pass): + assert isinstance(opt_pass, transform.Pass) + mod = relay.Module.from_expr(expr) + mod = opt_pass(mod) + return mod["main"] + + +def test_combine_parallel_dense(): + """Simple testcase.""" + def before(x, w1, w2, w3, w4): + args = [x, w1, w2, w3, w4] + y1 = relay.nn.dense(x, w1) + y2 = relay.nn.dense(x, w2) + # y3 cannot be combined + y3 = relay.nn.dense(x, w3) + y4 = relay.nn.dense(x, w4) + y = relay.Tuple((y1, y2, y3, y4)) + return relay.Function(args, y) + + def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): + # use a fixed order of args so alpha equal check can pass + args = [x, w1, w2, w3, w4] + w = relay.concatenate((w1, w2, w4), axis=0) + y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4) + y1 = relay.strided_slice(y, [0, 0], [None, channels1]) + y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) + y3 = relay.nn.conv2d(x, w3) + y4 = relay.strided_slice(y, [0, channels1 + channels2], + [None, channels1 + channels2 + channels4]) + y5 = relay.nn.max_pool2d(x) + y = relay.Tuple((y1, y2, y3, y4, y5)) + return relay.Function(args, y) + + def check(x_shape, channels1, channels2, channels3, channels4): + x = relay.var("x", shape=x_shape) + in_c = x_shape[1] + w1 = relay.var("w1", shape=(channels1, in_c, 1, 1)) + w2 = relay.var("w2", shape=(channels2, in_c, 1, 1)) + w3 = relay.var("w3", shape=(channels3, in_c, 3, 3)) + w4 = relay.var("w4", shape=(channels4, in_c, 1, 1)) + + y_before = before(x, w1, w2, w3, w4) + y = run_opt_pass(y_before, + transform.CombineParallelConv2D(min_num_branches=2)) + y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4) + y_expected = run_opt_pass(y_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y, y_expected) + + check((1, 4, 16, 16), 4, 4, 4, 4) + check((1, 4, 16, 16), 4, 8, 4, 7) + + +def test_combine_parallel_conv2d_scale_relu(): + """Testcase of combining conv2d + scale + relu""" + def before(x, w1, w2, scale1, scale2, bias): + args = [x, w1, w2, scale1, scale2, bias] + y1 = relay.nn.conv2d(x, w1) + y1 = relay.multiply(y1, scale1) + y1 = relay.nn.relu(y1) + y2 = relay.nn.conv2d(x, w2) + y2 = relay.multiply(y2, scale2) + y2 = relay.nn.relu(y2) + y2 = relay.add(y2, bias) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def expected(x, w1, w2, scale1, scale2, bias, channels1, channels2): + args = [x, w1, w2, scale1, scale2, bias] + w = relay.concatenate((w1, w2), axis=0) + scale = relay.concatenate((scale1, scale2), axis=0) + y = relay.nn.conv2d(x, w, channels=channels1 + channels2) + y = relay.multiply(y, scale) + y = relay.nn.relu(y) + y1 = relay.strided_slice(y, [0, 0], [None, channels1]) + y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) + y2 = relay.add(y2, bias) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def check(x_shape, channels1, channels2): + x = relay.var("x", shape=x_shape) + in_c = x_shape[1] + w1 = relay.var("w1", shape=(channels1, in_c, 1, 1)) + w2 = relay.var("w2", shape=(channels2, in_c, 1, 1)) + scale1 = relay.var("scale1", shape=(channels1, 1, 1)) + scale2 = relay.var("scale2", shape=(channels2, 1, 1)) + bias = relay.var("bias", shape=(channels2, 1, 1)) + y_before = before(x, w1, w2, scale1, scale2, bias) + y = run_opt_pass(y_before, + transform.CombineParallelConv2D(min_num_branches=2)) + y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2) + y_expected = run_opt_pass(y_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y, y_expected) + + check((1, 4, 16, 16), 4, 8) + + +def test_combine_parallel_conv2d_scale(): + """Testcase of un-combinable scale""" + def before(x, w1, w2, scale1, scale2): + args = [x, w1, w2, scale1, scale2] + y1 = relay.nn.conv2d(x, w1) + y1 = relay.multiply(y1, scale1) + y2 = relay.nn.conv2d(x, w2) + y2 = relay.multiply(y2, scale2) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def expected(x, w1, w2, scale1, scale2, channels1, channels2): + args = [x, w1, w2, scale1, scale2] + w = relay.concatenate((w1, w2), axis=0) + y = relay.nn.conv2d(x, w, channels=channels1 + channels2) + y1 = relay.strided_slice(y, [0, 0], [None, channels1]) + y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) + y1 = relay.multiply(y1, scale1) + y2 = relay.multiply(y2, scale2) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def check(x_shape, channels1, channels2): + x = relay.var("x", shape=x_shape) + in_c = x_shape[1] + w1 = relay.var("w1", shape=(channels1, in_c, 1, 1)) + w2 = relay.var("w2", shape=(channels2, in_c, 1, 1)) + scale1 = relay.var("scale1", shape=(1,)) + scale2 = relay.var("scale2", shape=(1,)) + y_before = before(x, w1, w2, scale1, scale2) + y = run_opt_pass(y_before, + transform.CombineParallelConv2D(min_num_branches=2)) + y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2) + y_expected = run_opt_pass(y_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y, y_expected) + + check((1, 4, 16, 16), 4, 8) + + +def test_combine_parallel_conv2d_multiple_blocks(): + def before(x, w, repeat): + args = [x, w] + y = x + for i in range(repeat): + y1 = relay.nn.conv2d(y, w) + y2 = relay.nn.conv2d(y, w) + y = relay.concatenate((y1, y2), axis=1) + return relay.Function(args, y) + + def expected(x, w, channels, repeat): + args = [x, w] + y = x + for i in range(repeat): + w_concat = relay.concatenate((w, w), axis=0) + y = relay.nn.conv2d(y, w_concat, channels=channels*2) + y1 = relay.strided_slice(y, [0, 0], [None, channels]) + y2 = relay.strided_slice(y, [0, channels], [None, channels * 2]) + y = relay.concatenate((y1, y2), axis=1) + return relay.Function(args, y) + + def check(x_shape, repeat): + x = relay.var("x", shape=x_shape) + in_c = x_shape[1] + out_c = in_c // 2 + w = relay.var("w", shape=(out_c, in_c, 1, 1)) + y_before = before(x, w, repeat) + y = run_opt_pass(y_before, + transform.CombineParallelConv2D(min_num_branches=2)) + y_expected = expected(x, w, out_c, repeat) + y_expected = run_opt_pass(y_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y, y_expected) + + check((1, 4, 16, 16), 4) + + +if __name__ == "__main__": + test_combine_parallel_dense() + #test_combine_parallel_dense_biasadd() From 2591c35d0ebeced1fdf84a1de149727b35ab7707 Mon Sep 17 00:00:00 2001 From: Jon Soifer Date: Thu, 29 Aug 2019 17:16:19 -0700 Subject: [PATCH 04/21] Test --- src/relay/pass/combine_parallel_dense.cc | 15 +- src/relay/pass/pattern_util.h | 8 +- .../relay/test_pass_combine_parallel_dense.py | 244 +++++++++--------- 3 files changed, 140 insertions(+), 127 deletions(-) diff --git a/src/relay/pass/combine_parallel_dense.cc b/src/relay/pass/combine_parallel_dense.cc index 25fe340fa4f3..87fba95fc12d 100644 --- a/src/relay/pass/combine_parallel_dense.cc +++ b/src/relay/pass/combine_parallel_dense.cc @@ -109,8 +109,8 @@ class ParallelDenseCombiner : public ParallelOpCombiner { weights.push_back(weight); } - Expr new_data = MakeStack(TupleNode::make(datas)); - Expr new_weight = MakeStack(TupleNode::make(weights)); + Expr new_data = MakeStack(TupleNode::make(datas), 0); + Expr new_weight = MakeStack(TupleNode::make(weights), 0); return CallNode::make(batch_matmul, {new_data, new_weight}, Attrs(), {}); } @@ -177,14 +177,14 @@ class ParallelDenseCombiner : public ParallelOpCombiner { Expr arg = branch[depth]->args[i]; const TensorTypeNode* arg_tensor = arg->type_as(); if (arg_tensor->shape.size() == 1) { - Expr expanded_arg = ExpandBiasToMatchAxis(arg, 2, {0}); + Expr expanded_arg = MakeExpandDims(arg, 0, 1); tuple.push_back(expanded_arg); } else { tuple.push_back(arg); } } - auto stack = MakeStack(TupleNode::make(tuple)); + auto stack = MakeStack(TupleNode::make(tuple), 0); new_args.push_back(std::move(stack)); } @@ -194,11 +194,12 @@ class ParallelDenseCombiner : public ParallelOpCombiner { // Replace output of each branch with slices of the combined output void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap& subst_map) { int index = 0; - auto split = MakeSplit(data, branches.size(), 0); + auto split = MakeSplit(data, Integer(branches.size()), 0); for (const auto& branch : branches) { const CallNode* dense = branch[0]; - auto split_data = TupleGetItemNode::make(split, index); - subst_map[GetRef(branch[depth])] = split_data; + auto split_data = TupleGetItemNode::make(split, index++); + auto squeezed_data = MakeSqueeze(split_data, {0}); + subst_map[GetRef(branch[depth])] = squeezed_data; } } }; diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 27a618b0ac1a..d4f7ebce46d8 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -419,9 +419,13 @@ Expr MakeConcatenate(Expr data, int axis); Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); -Expr MakeStack(Expr data); +Expr MakeStack(Expr data, int axis); -Expr MakeSplit(Expr data, int indices_or_sections, int axis); +Expr MakeSplit(Expr data, NodeRef indices_or_sections, int axis); + +Expr MakeSqueeze(Expr data, Array axis); + +Expr MakeExpandDims(Expr data, int axis, int num_newaxis); Expr StopFusion(Expr data); diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py index 227606a5bdee..6a7fa15ce567 100644 --- a/tests/python/relay/test_pass_combine_parallel_dense.py +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -31,170 +31,178 @@ def run_opt_pass(expr, opt_pass): def test_combine_parallel_dense(): - """Simple testcase.""" - def before(x, w1, w2, w3, w4): + """Simple testcase. Three can be combined, either because of mismatched shapes or units""" + def before(x, w1, w2, w3, w4, units): args = [x, w1, w2, w3, w4] y1 = relay.nn.dense(x, w1) y2 = relay.nn.dense(x, w2) + # y3 cannot be combined - y3 = relay.nn.dense(x, w3) + if units == -1: + y3 = relay.nn.dense(x, w3) + else: + y3 = relay.nn.dense(x, w3, units=units) + y4 = relay.nn.dense(x, w4) y = relay.Tuple((y1, y2, y3, y4)) return relay.Function(args, y) - def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): + def expected(x, w1, w2, w3, w4, units): # use a fixed order of args so alpha equal check can pass args = [x, w1, w2, w3, w4] - w = relay.concatenate((w1, w2, w4), axis=0) - y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4) - y1 = relay.strided_slice(y, [0, 0], [None, channels1]) - y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) - y3 = relay.nn.conv2d(x, w3) - y4 = relay.strided_slice(y, [0, channels1 + channels2], - [None, channels1 + channels2 + channels4]) - y5 = relay.nn.max_pool2d(x) - y = relay.Tuple((y1, y2, y3, y4, y5)) + x_stacked = relay.stack((x, x, x), axis=0) + w = relay.stack((w1, w2, w4), axis=0) + y = relay.nn.batch_matmul(x_stacked, w) + (y1, y2, y4) = relay.split(y, 3) + y1 = relay.squeeze(y1, [0]) + y2 = relay.squeeze(y2, [0]) + y4 = relay.squeeze(y4, [0]) + + if units == -1: + y3 = relay.nn.dense(x, w3) + else: + y3 = relay.nn.dense(x, w3, units=units) + + y = relay.Tuple((y1, y2, y3, y4)) return relay.Function(args, y) - def check(x_shape, channels1, channels2, channels3, channels4): - x = relay.var("x", shape=x_shape) - in_c = x_shape[1] - w1 = relay.var("w1", shape=(channels1, in_c, 1, 1)) - w2 = relay.var("w2", shape=(channels2, in_c, 1, 1)) - w3 = relay.var("w3", shape=(channels3, in_c, 3, 3)) - w4 = relay.var("w4", shape=(channels4, in_c, 1, 1)) + def check(i, j, k, use_units): + x = relay.var("x", shape=(i, k)) + w1 = relay.var("w1", shape=(j, k)) + w2 = relay.var("w2", shape=(j, k)) - y_before = before(x, w1, w2, w3, w4) + if use_units: + units = j + w3 = relay.var("w3", shape=(j, k)) + else: + units = -1 + w3 = relay.var("w3", shape=(j + 1, k)) + + w4 = relay.var("w4", shape=(j, k)) + + y_before = before(x, w1, w2, w3, w4, units) y = run_opt_pass(y_before, - transform.CombineParallelConv2D(min_num_branches=2)) - y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4) + transform.CombineParallelDense(min_num_branches=2)) + y_expected = expected(x, w1, w2, w3, w4, units) y_expected = run_opt_pass(y_expected, transform.InferType()) assert relay.analysis.alpha_equal(y, y_expected) - check((1, 4, 16, 16), 4, 4, 4, 4) - check((1, 4, 16, 16), 4, 8, 4, 7) + check(3, 5, 4, False) + check(100, 200, 300, False) + check(3, 5, 4, True) + check(100, 200, 300, True) -def test_combine_parallel_conv2d_scale_relu(): - """Testcase of combining conv2d + scale + relu""" - def before(x, w1, w2, scale1, scale2, bias): - args = [x, w1, w2, scale1, scale2, bias] - y1 = relay.nn.conv2d(x, w1) - y1 = relay.multiply(y1, scale1) - y1 = relay.nn.relu(y1) - y2 = relay.nn.conv2d(x, w2) - y2 = relay.multiply(y2, scale2) - y2 = relay.nn.relu(y2) - y2 = relay.add(y2, bias) +def test_combine_parallel_dense_biasadd(): + """Testcase of combining dense + 1d biasadd""" + def before(x, w1, w2, b1, b2): + args = [x, w1, w2, b1, b2] + y1 = relay.nn.dense(x, w1) + y2 = relay.nn.dense(x, w2) + y1 = relay.add(y1, b1) + y2 = relay.add(y2, b2) y = relay.Tuple((y1, y2)) return relay.Function(args, y) - def expected(x, w1, w2, scale1, scale2, bias, channels1, channels2): - args = [x, w1, w2, scale1, scale2, bias] - w = relay.concatenate((w1, w2), axis=0) - scale = relay.concatenate((scale1, scale2), axis=0) - y = relay.nn.conv2d(x, w, channels=channels1 + channels2) - y = relay.multiply(y, scale) - y = relay.nn.relu(y) - y1 = relay.strided_slice(y, [0, 0], [None, channels1]) - y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) - y2 = relay.add(y2, bias) + def expected(x, w1, w2, b1, b2, is_2d_bias): + args = [x, w1, w2, b1, b2] + x_stacked = relay.stack((x, x), axis=0) + w = relay.stack((w1, w2), axis=0) + y = relay.nn.batch_matmul(x_stacked, w) + + if not is_2d_bias: + b1 = relay.expand_dims(b1, 0) + b2 = relay.expand_dims(b2, 0) + + b = relay.stack((b1, b2), axis=0) + y = relay.add(y, b) + (y1, y2) = relay.split(y, 2) + y1 = relay.squeeze(y1, [0]) + y2 = relay.squeeze(y2, [0]) y = relay.Tuple((y1, y2)) return relay.Function(args, y) - def check(x_shape, channels1, channels2): - x = relay.var("x", shape=x_shape) - in_c = x_shape[1] - w1 = relay.var("w1", shape=(channels1, in_c, 1, 1)) - w2 = relay.var("w2", shape=(channels2, in_c, 1, 1)) - scale1 = relay.var("scale1", shape=(channels1, 1, 1)) - scale2 = relay.var("scale2", shape=(channels2, 1, 1)) - bias = relay.var("bias", shape=(channels2, 1, 1)) - y_before = before(x, w1, w2, scale1, scale2, bias) + def check(i, j, k, is_2d_bias): + x = relay.var("x", shape=(i, k)) + w1 = relay.var("w1", shape=(j, k)) + w2 = relay.var("w2", shape=(j, k)) + + if is_2d_bias: + b1 = relay.var("b1", shape=(i, j)) + b2 = relay.var("b2", shape=(i, j)) + else: + b1 = relay.var("b1", shape=(j,)) + b2 = relay.var("b2", shape=(j,)) + + y_before = before(x, w1, w2, b1, b2) y = run_opt_pass(y_before, - transform.CombineParallelConv2D(min_num_branches=2)) - y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2) + transform.CombineParallelDense(min_num_branches=2)) + y_expected = expected(x, w1, w2, b1, b2, is_2d_bias) y_expected = run_opt_pass(y_expected, transform.InferType()) assert relay.analysis.alpha_equal(y, y_expected) - check((1, 4, 16, 16), 4, 8) - + check(3, 5, 4, False) + check(100, 200, 300, False) + check(3, 5, 4, True) + check(100, 200, 300, True) -def test_combine_parallel_conv2d_scale(): - """Testcase of un-combinable scale""" - def before(x, w1, w2, scale1, scale2): - args = [x, w1, w2, scale1, scale2] - y1 = relay.nn.conv2d(x, w1) +def test_combine_parallel_dense_biasadd_scale_reshape(): + """Testcase of combining dense + 1d biasadd""" + def before(x, w1, w2, b1, b2, scale1, scale2, newshape): + args = [x, w1, w2, b1, b2, scale1, scale2] + y1 = relay.nn.dense(x, w1) + y2 = relay.nn.dense(x, w2) + y1 = relay.add(y1, b1) + y2 = relay.add(y2, b2) y1 = relay.multiply(y1, scale1) - y2 = relay.nn.conv2d(x, w2) y2 = relay.multiply(y2, scale2) + y1 = relay.reshape(y1, newshape=newshape) + y2 = relay.reshape(y2, newshape=newshape) y = relay.Tuple((y1, y2)) return relay.Function(args, y) - def expected(x, w1, w2, scale1, scale2, channels1, channels2): - args = [x, w1, w2, scale1, scale2] - w = relay.concatenate((w1, w2), axis=0) - y = relay.nn.conv2d(x, w, channels=channels1 + channels2) - y1 = relay.strided_slice(y, [0, 0], [None, channels1]) - y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) - y1 = relay.multiply(y1, scale1) - y2 = relay.multiply(y2, scale2) + def expected(x, w1, w2, b1, b2, scale1, scale2, newshape): + args = [x, w1, w2, b1, b2, scale1, scale2] + x_stacked = relay.stack((x, x), axis=0) + w = relay.stack((w1, w2), axis=0) + y = relay.nn.batch_matmul(x_stacked, w) + b1 = relay.expand_dims(b1, 0) + b2 = relay.expand_dims(b2, 0) + b = relay.stack((b1, b2), axis=0) + y = relay.add(y, b) + scale1 = relay.expand_dims(scale1, 0) + scale2 = relay.expand_dims(scale2, 0) + scale = relay.stack((scale1, scale2), axis=0) + y = relay.multiply(y, scale) + (y1, y2) = relay.split(y, 2) + y1 = relay.squeeze(y1, [0]) + y2 = relay.squeeze(y2, [0]) + y1 = relay.reshape(y1, newshape=newshape) + y2 = relay.reshape(y2, newshape=newshape) y = relay.Tuple((y1, y2)) return relay.Function(args, y) - def check(x_shape, channels1, channels2): - x = relay.var("x", shape=x_shape) - in_c = x_shape[1] - w1 = relay.var("w1", shape=(channels1, in_c, 1, 1)) - w2 = relay.var("w2", shape=(channels2, in_c, 1, 1)) + def check(i, j, k, scale1, scale2, newshape): + x = relay.var("x", shape=(i, k)) + w1 = relay.var("w1", shape=(j, k)) + w2 = relay.var("w2", shape=(j, k)) + b1 = relay.var("b1", shape=(j,)) + b2 = relay.var("b2", shape=(j,)) scale1 = relay.var("scale1", shape=(1,)) scale2 = relay.var("scale2", shape=(1,)) - y_before = before(x, w1, w2, scale1, scale2) - y = run_opt_pass(y_before, - transform.CombineParallelConv2D(min_num_branches=2)) - y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2) - y_expected = run_opt_pass(y_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y, y_expected) - - check((1, 4, 16, 16), 4, 8) - - -def test_combine_parallel_conv2d_multiple_blocks(): - def before(x, w, repeat): - args = [x, w] - y = x - for i in range(repeat): - y1 = relay.nn.conv2d(y, w) - y2 = relay.nn.conv2d(y, w) - y = relay.concatenate((y1, y2), axis=1) - return relay.Function(args, y) - - def expected(x, w, channels, repeat): - args = [x, w] - y = x - for i in range(repeat): - w_concat = relay.concatenate((w, w), axis=0) - y = relay.nn.conv2d(y, w_concat, channels=channels*2) - y1 = relay.strided_slice(y, [0, 0], [None, channels]) - y2 = relay.strided_slice(y, [0, channels], [None, channels * 2]) - y = relay.concatenate((y1, y2), axis=1) - return relay.Function(args, y) - def check(x_shape, repeat): - x = relay.var("x", shape=x_shape) - in_c = x_shape[1] - out_c = in_c // 2 - w = relay.var("w", shape=(out_c, in_c, 1, 1)) - y_before = before(x, w, repeat) + y_before = before(x, w1, w2, b1, b2, scale1, scale2, newshape) y = run_opt_pass(y_before, - transform.CombineParallelConv2D(min_num_branches=2)) - y_expected = expected(x, w, out_c, repeat) + transform.CombineParallelDense(min_num_branches=2)) + y_expected = expected(x, w1, w2, b1, b2, scale1, scale2, newshape) y_expected = run_opt_pass(y_expected, transform.InferType()) assert relay.analysis.alpha_equal(y, y_expected) - check((1, 4, 16, 16), 4) + check(3, 5, 4, 0.5, 0.25, (1, 1, 15)) + check(100, 200, 300, 0.5, 0.25, (1, 1, 200)) if __name__ == "__main__": test_combine_parallel_dense() - #test_combine_parallel_dense_biasadd() + test_combine_parallel_dense_biasadd() + test_combine_parallel_dense_biasadd_scale_reshape() From b9c1128f45c742427207e258fb2c50eed57f4643 Mon Sep 17 00:00:00 2001 From: Jon Soifer Date: Thu, 29 Aug 2019 17:55:14 -0700 Subject: [PATCH 05/21] Refactor to move out more common code --- src/relay/pass/combine_parallel_conv2d.cc | 111 ++++++---------------- src/relay/pass/combine_parallel_dense.cc | 63 +----------- src/relay/pass/combine_parallel_op.cc | 52 +++++++++- src/relay/pass/combine_parallel_op.h | 23 +++-- 4 files changed, 100 insertions(+), 149 deletions(-) diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index ee479aaafba0..f59dc302f3c1 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -54,11 +54,11 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { } protected: - virtual bool IsSupportedOp(const CallNode* n) { + bool IsSupportedOp(const CallNode* n) { return n->attrs.as()->groups == 1; } - virtual bool AreCompatibleOps(const CallNode* a, const CallNode* b) { + bool AreCompatibleOps(const CallNode* a, const CallNode* b) { AttrsEqual eq; static const Layout kOIHW("OIHW"); const auto* attrs_a = a->attrs.as(); @@ -81,50 +81,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { eq(shape_a[3], shape_b[3]); } - virtual void CombineBranches(const Group& branches, ExprSubstMap& subst_map) { - Call combined = MakeCombinedConv2D(branches); - auto conv_param = combined->attrs.as(); - const std::string& layout = - conv_param->out_layout == "" ? conv_param->data_layout : conv_param->out_layout; - size_t channel_pos = layout.find('C'); - CHECK_NE(channel_pos, std::string::npos); - auto it = std::min_element(branches.begin(), branches.end(), - [](const Branch& branch_a, - const Branch& branch_b) { - return branch_a.size() < branch_b.size(); - }); - size_t depth = it->size(); - size_t i; - // starting from 1 to skip the conv2d - for (i = 1; i < depth; i++) { - size_t parent_index; - for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) { - if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break; - } - CHECK_NE(parent_index, branches[0][i]->args.size()); - if (!CheckLevel(branches, i, channel_pos, parent_index)) break; - combined = MakeCombinedCall(combined, branches, i, channel_pos, parent_index); - } - UpdateGroupOutput(combined, branches, i - 1, channel_pos, subst_map); - } - - private: - std::tuple TransformWeight(const Group& branches) { - int64_t num_filters = 0; // number of filters of the transformed weight - Array weights; - for (const auto& branch : branches) { - auto conv2d = branch[0]; - weights.push_back(conv2d->args[1]); - auto channels = GetConv2DSuperChannelsDim(conv2d); - num_filters += channels; - } - auto index = branches[0][0]->attrs.as()->kernel_layout.find('O'); - CHECK_NE(index, std::string::npos); - return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), - MakeConstScalar(Int(32), num_filters)); - } - - Call MakeCombinedConv2D(const Group& branches) { + Call MakeCombinedOp(const Group& branches) { static const Op& conv2d = Op::Get("nn.conv2d"); Expr data = branches[0][0]->args[0]; Expr new_weight; @@ -146,10 +103,15 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { new_attrs->out_dtype = attrs->out_dtype; new_attrs->channels = new_channels; + const std::string& layout = + new_attrs->out_layout == "" ? new_attrs->data_layout : new_attrs->out_layout; + channel_pos = layout.find('C'); + CHECK_NE(channel_pos, std::string::npos); + return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {}); } - bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index, size_t channel_pos) { + bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { AttrsEqual eq; auto ta = a->args[index]->type_as(); auto tb = b->args[index]->type_as(); @@ -176,38 +138,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { return true; } - // Check if ops in depth-th level can be combined - bool CheckLevel(const Group& branches, size_t depth, size_t channel_pos, size_t parent_index) { - const CallNode* call = branches[0][depth]; - AttrsEqual attrs_equal; - // check if all branches in current depth can be combined - for (auto it = branches.begin() + 1; it != branches.end(); it++) { - const Branch& branch = *it; - if (!branch[depth]->op.same_as(call->op) || - !attrs_equal(branch[depth]->attrs, call->attrs) || - branch[depth]->args.size() != call->args.size()) { - return false; - } - - if (branch[depth]->args[parent_index].get() != branch[depth - 1]) - return false; - - // Check args - for (size_t i = 0; i < call->args.size(); i++) { - if (i == parent_index) continue; - - if (!IsArgCompatible(call, branch[depth], i, channel_pos) || - !attrs_equal(call->attrs, branch[depth]->attrs)) { - return false; - } - } - } - return true; - } - - // Combine args and make the combined CallNode - Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t channel_pos, - size_t parent_index) { + Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t parent_index) { Array new_args; const CallNode* call = branches[0][depth]; size_t ndim = call->type_as()->shape.size(); @@ -229,9 +160,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { return CallNode::make(call->op, new_args, call->attrs, {}); } - // Replace output of each branch with slices of the combined output - void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, - size_t channel_pos, ExprSubstMap& subst_map) { + void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap& subst_map) { int64_t index = 0; for (const auto& branch : branches) { const CallNode* conv2d = branch[0]; @@ -249,6 +178,24 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { subst_map[GetRef(branch[depth])] = slice; } } + + private: + size_t channel_pos; + + std::tuple TransformWeight(const Group& branches) { + int64_t num_filters = 0; // number of filters of the transformed weight + Array weights; + for (const auto& branch : branches) { + auto conv2d = branch[0]; + weights.push_back(conv2d->args[1]); + auto channels = GetConv2DSuperChannelsDim(conv2d); + num_filters += channels; + } + auto index = branches[0][0]->attrs.as()->kernel_layout.find('O'); + CHECK_NE(index, std::string::npos); + return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), + MakeConstScalar(Int(32), num_filters)); + } }; /*! \brief Combine parallel conv2d if number of branches >= min_num_branches */ diff --git a/src/relay/pass/combine_parallel_dense.cc b/src/relay/pass/combine_parallel_dense.cc index 87fba95fc12d..08a02403fa9f 100644 --- a/src/relay/pass/combine_parallel_dense.cc +++ b/src/relay/pass/combine_parallel_dense.cc @@ -53,12 +53,12 @@ class ParallelDenseCombiner : public ParallelOpCombiner { } protected: - virtual bool IsSupportedOp(const CallNode* n) { + bool IsSupportedOp(const CallNode* n) { const auto* attrs = n->attrs.as(); return !attrs->units.defined(); } - virtual bool AreCompatibleOps(const CallNode* a, const CallNode* b) { + bool AreCompatibleOps(const CallNode* a, const CallNode* b) { AttrsEqual eq; const auto* attrs_a = a->attrs.as(); const auto* attrs_b = b->attrs.as(); @@ -73,31 +73,7 @@ class ParallelDenseCombiner : public ParallelOpCombiner { eq(attrs_a->units.defined(), attrs_b->units.defined()); } - virtual void CombineBranches(const Group& branches, ExprSubstMap& subst_map) { - Call combined = MakeCombinedDense(branches); - auto it = std::min_element(branches.begin(), branches.end(), - [](const Branch& branch_a, - const Branch& branch_b) { - return branch_a.size() < branch_b.size(); - }); - size_t depth = it->size(); - size_t i; - // starting from 1 to skip the dense - for (i = 1; i < depth; i++) { - size_t parent_index; - for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) { - if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break; - } - CHECK_NE(parent_index, branches[0][i]->args.size()); - if (!CheckLevel(branches, i, parent_index)) break; - combined = MakeCombinedCall(combined, branches, i, parent_index); - } - UpdateGroupOutput(combined, branches, i - 1, subst_map); - } - - private: - // Combine dense into batch matmul. - Call MakeCombinedDense(const Group& branches) { + Call MakeCombinedOp(const Group& branches) { static const Op& batch_matmul = Op::Get("nn.batch_matmul"); Array datas; Array weights; @@ -118,8 +94,6 @@ class ParallelDenseCombiner : public ParallelOpCombiner { AttrsEqual eq; auto ta = a->args[index]->type_as(); auto tb = b->args[index]->type_as(); - auto toutput_a = a->type_as(); - auto toutput_b = b->type_as(); if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) return false; @@ -131,36 +105,6 @@ class ParallelDenseCombiner : public ParallelOpCombiner { return true; } - // Check if ops in depth-th level can be combined - bool CheckLevel(const Group& branches, size_t depth, size_t parent_index) { - const CallNode* call = branches[0][depth]; - AttrsEqual attrs_equal; - // check if all branches in current depth can be combined - for (auto it = branches.begin() + 1; it != branches.end(); it++) { - const Branch& branch = *it; - if (!branch[depth]->op.same_as(call->op) || - !attrs_equal(branch[depth]->attrs, call->attrs) || - branch[depth]->args.size() != call->args.size()) { - return false; - } - - if (branch[depth]->args[parent_index].get() != branch[depth - 1]) - return false; - - // Check args - for (size_t i = 0; i < call->args.size(); i++) { - if (i == parent_index) continue; - - if (!IsArgCompatible(call, branch[depth], i) || - !attrs_equal(call->attrs, branch[depth]->attrs)) { - return false; - } - } - } - return true; - } - - // Combine args and make the combined CallNode Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t parent_index) { Array new_args; const CallNode* call = branches[0][depth]; @@ -191,7 +135,6 @@ class ParallelDenseCombiner : public ParallelOpCombiner { return CallNode::make(call->op, new_args, call->attrs, {}); } - // Replace output of each branch with slices of the combined output void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap& subst_map) { int index = 0; auto split = MakeSplit(data, Integer(branches.size()), 0); diff --git a/src/relay/pass/combine_parallel_op.cc b/src/relay/pass/combine_parallel_op.cc index 08be9eb9195c..3cf9bda3fd96 100644 --- a/src/relay/pass/combine_parallel_op.cc +++ b/src/relay/pass/combine_parallel_op.cc @@ -127,10 +127,60 @@ Expr ParallelOpCombiner::Combine(const Expr& expr) { if (group.size() < min_num_branches_) { continue; } - CombineBranches(group, subst_map_); + CombineBranches(group); } return ExprSubst(expr, std::move(subst_map_)); } +void ParallelOpCombiner::CombineBranches(const Group& branches) { + Call combined = MakeCombinedOp(branches); + auto it = std::min_element(branches.begin(), branches.end(), + [](const Branch& branch_a, + const Branch& branch_b) { + return branch_a.size() < branch_b.size(); + }); + size_t depth = it->size(); + size_t i; + // starting from 1 to skip the dense + for (i = 1; i < depth; i++) { + size_t parent_index; + for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) { + if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break; + } + CHECK_NE(parent_index, branches[0][i]->args.size()); + if (!CheckLevel(branches, i, parent_index)) break; + combined = MakeCombinedCall(combined, branches, i, parent_index); + } + UpdateGroupOutput(combined, branches, i - 1, subst_map_); +} + +bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) { + const CallNode* call = branches[0][depth]; + AttrsEqual attrs_equal; + // check if all branches in current depth can be combined + for (auto it = branches.begin() + 1; it != branches.end(); it++) { + const Branch& branch = *it; + if (!branch[depth]->op.same_as(call->op) || + !attrs_equal(branch[depth]->attrs, call->attrs) || + branch[depth]->args.size() != call->args.size()) { + return false; + } + + if (branch[depth]->args[parent_index].get() != branch[depth - 1]) + return false; + + // Check args + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) continue; + + if (!IsArgCompatible(call, branch[depth], i) || + !attrs_equal(call->attrs, branch[depth]->attrs)) { + return false; + } + } + } + return true; + } + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/combine_parallel_op.h b/src/relay/pass/combine_parallel_op.h index bcb3cd37a6dd..19f424a90fbe 100644 --- a/src/relay/pass/combine_parallel_op.h +++ b/src/relay/pass/combine_parallel_op.h @@ -98,17 +98,28 @@ class ParallelOpCombiner { // Otherwise, returns false. virtual bool AreCompatibleOps(const CallNode* a, const CallNode* b) = 0; - // Combine branches in a group. Ops in different branches in the same group are safe to - // combine. Subsequent ops may or may not be combined. Start from op and try to - // combine ops from all branches in the same depth. - // Ops should be updated by updating subst_map, - // which maps original Expr to Expr to substitute it with. - virtual void CombineBranches(const Group& branches, ExprSubstMap& subst_map) = 0; + // Create Call that consists of the combined ops. This usually involves concatenating + // or stacking inputs, then creating a new call. + virtual Call MakeCombinedOp(const Group& branches) = 0; + + // Returns true if arguments of a and b at index index can be combined. + virtual bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) = 0; + + // Create combined call of other ops in depth-th level. This usually involves concatenating + // or stacking inputs, then creating a new call. + virtual Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t parent_index) = 0; + + // Replace output of each branch with slices of the combined output. + virtual void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap& subst_map) = 0; private: std::string op_name_; uint64_t min_num_branches_; ExprSubstMap subst_map_; + + void CombineBranches(const Group& branches); + + bool CheckLevel(const Group& branches, size_t depth, size_t parent_index); }; } // namespace relay From 08e9ccf4d15517445eec7983e56ac626995c9370 Mon Sep 17 00:00:00 2001 From: Jon Soifer Date: Fri, 30 Aug 2019 09:01:31 -0700 Subject: [PATCH 06/21] Clean up --- src/relay/pass/combine_parallel_conv2d.cc | 18 ++++++++++++++---- src/relay/pass/combine_parallel_dense.cc | 16 +++++++++++----- src/relay/pass/combine_parallel_op.cc | 16 ++++++++-------- src/relay/pass/combine_parallel_op.h | 21 +++++++++++++++------ 4 files changed, 48 insertions(+), 23 deletions(-) diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index f59dc302f3c1..05bca5d8ea27 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -50,7 +50,8 @@ namespace relay { class ParallelConv2DCombiner : public ParallelOpCombiner { public: - ParallelConv2DCombiner(uint64_t min_num_branches) : ParallelOpCombiner("nn.conv2d", min_num_branches) { + explicit ParallelConv2DCombiner(uint64_t min_num_branches) + : ParallelOpCombiner("nn.conv2d", min_num_branches) { } protected: @@ -58,7 +59,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { return n->attrs.as()->groups == 1; } - bool AreCompatibleOps(const CallNode* a, const CallNode* b) { + bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { AttrsEqual eq; static const Layout kOIHW("OIHW"); const auto* attrs_a = a->attrs.as(); @@ -138,7 +139,10 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { return true; } - Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t parent_index) { + Call MakeCombinedCallFromFollowingOps(const Expr& data, + const Group& branches, + size_t depth, + size_t parent_index) { Array new_args; const CallNode* call = branches[0][depth]; size_t ndim = call->type_as()->shape.size(); @@ -148,19 +152,25 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { new_args.push_back(data); continue; } + size_t arg_ndim = call->args[i]->type_as()->shape.size(); size_t arg_channel_pos = channel_pos - ndim + arg_ndim; Array tuple; for (const auto& branch : branches) { tuple.push_back(branch[depth]->args[i]); } + auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos); new_args.push_back(std::move(concat)); } + return CallNode::make(call->op, new_args, call->attrs, {}); } - void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap& subst_map) { + void UpdateGroupOutput(const Expr& data, + const Group& branches, + size_t depth, + ExprSubstMap& subst_map) { int64_t index = 0; for (const auto& branch : branches) { const CallNode* conv2d = branch[0]; diff --git a/src/relay/pass/combine_parallel_dense.cc b/src/relay/pass/combine_parallel_dense.cc index 08a02403fa9f..1a7551b2427e 100644 --- a/src/relay/pass/combine_parallel_dense.cc +++ b/src/relay/pass/combine_parallel_dense.cc @@ -49,7 +49,8 @@ namespace relay { class ParallelDenseCombiner : public ParallelOpCombiner { public: - ParallelDenseCombiner(uint64_t min_num_branches) : ParallelOpCombiner("nn.dense", min_num_branches) { + explicit ParallelDenseCombiner(uint64_t min_num_branches) + : ParallelOpCombiner("nn.dense", min_num_branches) { } protected: @@ -58,7 +59,7 @@ class ParallelDenseCombiner : public ParallelOpCombiner { return !attrs->units.defined(); } - bool AreCompatibleOps(const CallNode* a, const CallNode* b) { + bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { AttrsEqual eq; const auto* attrs_a = a->attrs.as(); const auto* attrs_b = b->attrs.as(); @@ -105,7 +106,10 @@ class ParallelDenseCombiner : public ParallelOpCombiner { return true; } - Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t parent_index) { + Call MakeCombinedCallFromFollowingOps(const Expr& data, + const Group& branches, + size_t depth, + size_t parent_index) { Array new_args; const CallNode* call = branches[0][depth]; @@ -135,11 +139,13 @@ class ParallelDenseCombiner : public ParallelOpCombiner { return CallNode::make(call->op, new_args, call->attrs, {}); } - void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap& subst_map) { + void UpdateGroupOutput(const Expr& data, + const Group& branches, + size_t depth, + ExprSubstMap& subst_map) { int index = 0; auto split = MakeSplit(data, Integer(branches.size()), 0); for (const auto& branch : branches) { - const CallNode* dense = branch[0]; auto split_data = TupleGetItemNode::make(split, index++); auto squeezed_data = MakeSqueeze(split_data, {0}); subst_map[GetRef(branch[depth])] = squeezed_data; diff --git a/src/relay/pass/combine_parallel_op.cc b/src/relay/pass/combine_parallel_op.cc index 3cf9bda3fd96..717b76ceae2c 100644 --- a/src/relay/pass/combine_parallel_op.cc +++ b/src/relay/pass/combine_parallel_op.cc @@ -42,7 +42,7 @@ namespace relay { BranchGroupFinder::BranchGroupFinder(const std::string& op_name, FIsSupportedOp fis_supported_op, - FAreCompatibleOps fare_compatible_ops) + FAreCompatibleOps fare_compatible_ops) : op_name_(op_name), fis_supported_op_(fis_supported_op), fare_compatible_ops_(fare_compatible_ops) { @@ -110,18 +110,18 @@ void BranchGroupFinder::VisitExpr_(const CallNode* n) { } } -ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches) +ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches) : op_name_(op_name), min_num_branches_(min_num_branches) { } Expr ParallelOpCombiner::Combine(const Expr& expr) { auto groups = BranchGroupFinder(op_name_, - [&](const CallNode* n) { - return IsSupportedOp(n); + [&](const CallNode* n) { + return IsSupportedOp(n); }, - [&](const CallNode* a, const CallNode* b) { - return AreCompatibleOps(a, b); + [&](const CallNode* a, const CallNode* b) { + return CanOpsBeCombined(a, b); }).Find(expr); for (const Group& group : groups) { if (group.size() < min_num_branches_) { @@ -149,9 +149,9 @@ void ParallelOpCombiner::CombineBranches(const Group& branches) { } CHECK_NE(parent_index, branches[0][i]->args.size()); if (!CheckLevel(branches, i, parent_index)) break; - combined = MakeCombinedCall(combined, branches, i, parent_index); + combined = MakeCombinedCallFromFollowingOps(combined, branches, i, parent_index); } - UpdateGroupOutput(combined, branches, i - 1, subst_map_); + UpdateGroupOutput(combined, branches, i - 1, std::move(subst_map_)); } bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) { diff --git a/src/relay/pass/combine_parallel_op.h b/src/relay/pass/combine_parallel_op.h index 19f424a90fbe..4db953998eb0 100644 --- a/src/relay/pass/combine_parallel_op.h +++ b/src/relay/pass/combine_parallel_op.h @@ -32,6 +32,8 @@ #include #include #include +#include +#include #include "./expr_subst.h" #include "./pattern_util.h" @@ -88,7 +90,7 @@ class ParallelOpCombiner { explicit ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches); Expr Combine(const Expr& expr); - + protected: // Returns true if the op represented by CallNode n is supported to be the // root of a branch to be combined. Otherwise, returns false. @@ -96,7 +98,7 @@ class ParallelOpCombiner { // Returns true if ops represented by CallNodes a and b can be combined. // Otherwise, returns false. - virtual bool AreCompatibleOps(const CallNode* a, const CallNode* b) = 0; + virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) = 0; // Create Call that consists of the combined ops. This usually involves concatenating // or stacking inputs, then creating a new call. @@ -105,12 +107,19 @@ class ParallelOpCombiner { // Returns true if arguments of a and b at index index can be combined. virtual bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) = 0; - // Create combined call of other ops in depth-th level. This usually involves concatenating - // or stacking inputs, then creating a new call. - virtual Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t parent_index) = 0; + // Create combined call of ops that follow initial combined op in depth-th level. + // This usually involves concatenating or stacking inputs, then creating a new call. + // Only called if IsArgCompatible returns true for each arg. + virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, + const Group& branches, + size_t depth, + size_t parent_index) = 0; // Replace output of each branch with slices of the combined output. - virtual void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap& subst_map) = 0; + virtual void UpdateGroupOutput(const Expr& data, + const Group& branches, + size_t depth, + ExprSubstMap& subst_map) = 0; private: std::string op_name_; From 737d1cd19cb12aa6d82404d9d65c1c4ce61a8f7c Mon Sep 17 00:00:00 2001 From: Jon Date: Fri, 30 Aug 2019 10:14:09 -0700 Subject: [PATCH 07/21] Fix --- src/relay/pass/combine_parallel_conv2d.cc | 4 ++-- src/relay/pass/combine_parallel_dense.cc | 6 +++--- src/relay/pass/combine_parallel_op.cc | 2 +- src/relay/pass/combine_parallel_op.h | 7 +++++-- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 05bca5d8ea27..736a297f6d22 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -170,7 +170,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, - ExprSubstMap& subst_map) { + ExprSubstMap* subst_map) { int64_t index = 0; for (const auto& branch : branches) { const CallNode* conv2d = branch[0]; @@ -185,7 +185,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { index += channels; end.push_back(index); auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array{}); - subst_map[GetRef(branch[depth])] = slice; + subst_map->insert({GetRef(branch[depth]), slice}); } } diff --git a/src/relay/pass/combine_parallel_dense.cc b/src/relay/pass/combine_parallel_dense.cc index 1a7551b2427e..f5929f5cca7b 100644 --- a/src/relay/pass/combine_parallel_dense.cc +++ b/src/relay/pass/combine_parallel_dense.cc @@ -49,7 +49,7 @@ namespace relay { class ParallelDenseCombiner : public ParallelOpCombiner { public: - explicit ParallelDenseCombiner(uint64_t min_num_branches) + explicit ParallelDenseCombiner(uint64_t min_num_branches) : ParallelOpCombiner("nn.dense", min_num_branches) { } @@ -142,13 +142,13 @@ class ParallelDenseCombiner : public ParallelOpCombiner { void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, - ExprSubstMap& subst_map) { + ExprSubstMap* subst_map) { int index = 0; auto split = MakeSplit(data, Integer(branches.size()), 0); for (const auto& branch : branches) { auto split_data = TupleGetItemNode::make(split, index++); auto squeezed_data = MakeSqueeze(split_data, {0}); - subst_map[GetRef(branch[depth])] = squeezed_data; + subst_map->insert({GetRef(branch[depth]), squeezed_data}); } } }; diff --git a/src/relay/pass/combine_parallel_op.cc b/src/relay/pass/combine_parallel_op.cc index 717b76ceae2c..25144f952446 100644 --- a/src/relay/pass/combine_parallel_op.cc +++ b/src/relay/pass/combine_parallel_op.cc @@ -151,7 +151,7 @@ void ParallelOpCombiner::CombineBranches(const Group& branches) { if (!CheckLevel(branches, i, parent_index)) break; combined = MakeCombinedCallFromFollowingOps(combined, branches, i, parent_index); } - UpdateGroupOutput(combined, branches, i - 1, std::move(subst_map_)); + UpdateGroupOutput(combined, branches, i - 1, &subst_map_); } bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) { diff --git a/src/relay/pass/combine_parallel_op.h b/src/relay/pass/combine_parallel_op.h index 4db953998eb0..c59dd2c3267e 100644 --- a/src/relay/pass/combine_parallel_op.h +++ b/src/relay/pass/combine_parallel_op.h @@ -23,6 +23,8 @@ * \file combine_parallel_op.h * \brief Abstract class to combine parallel ops and their successive element-wise ops. */ +#ifndef TVM_RELAY_PASS_COMBINE_PARALLEL_OP_H_ +#define TVM_RELAY_PASS_COMBINE_PARALLEL_OP_H_ #include #include @@ -107,7 +109,7 @@ class ParallelOpCombiner { // Returns true if arguments of a and b at index index can be combined. virtual bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) = 0; - // Create combined call of ops that follow initial combined op in depth-th level. + // Create combined call of ops that follow initial combined op in depth-th level. // This usually involves concatenating or stacking inputs, then creating a new call. // Only called if IsArgCompatible returns true for each arg. virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, @@ -119,7 +121,7 @@ class ParallelOpCombiner { virtual void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, - ExprSubstMap& subst_map) = 0; + ExprSubstMap* subst_map) = 0; private: std::string op_name_; @@ -133,3 +135,4 @@ class ParallelOpCombiner { } // namespace relay } // namespace tvm +#endif // TVM_RELAY_PASS_COMBINE_PARALLEL_OP_H_ From fc95cde7fda76db2c801d06302ceb1ca0af1e108 Mon Sep 17 00:00:00 2001 From: Jon Date: Fri, 30 Aug 2019 17:36:28 -0700 Subject: [PATCH 08/21] Remove statics --- src/relay/pass/combine_parallel_op.cc | 6 +++--- tests/python/relay/test_pass_combine_parallel_dense.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/relay/pass/combine_parallel_op.cc b/src/relay/pass/combine_parallel_op.cc index 25144f952446..73b95e9577cf 100644 --- a/src/relay/pass/combine_parallel_op.cc +++ b/src/relay/pass/combine_parallel_op.cc @@ -49,7 +49,7 @@ BranchGroupFinder::BranchGroupFinder(const std::string& op_name, } std::vector BranchGroupFinder::Find(const Expr& expr) { - static const Op& op = Op::Get(op_name_); + const Op& op = Op::Get(op_name_); this->VisitExpr(expr); @@ -80,7 +80,7 @@ std::vector BranchGroupFinder::Find(const Expr& expr) { // Create a branch starting from op. Branch BranchGroupFinder::CreateBranch(const CallNode* op) { - static auto fpattern = Op::GetAttr("TOpPattern"); + auto fpattern = Op::GetAttr("TOpPattern"); // each branch has at least one element, the first element is always op Branch branch{op}; auto it = children_map_.find(GetRef(branch.back())); @@ -98,7 +98,7 @@ Branch BranchGroupFinder::CreateBranch(const CallNode* op) { } void BranchGroupFinder::VisitExpr_(const CallNode* n) { - static const Op& op = Op::Get(op_name_); + const Op& op = Op::Get(op_name_); ExprVisitor::VisitExpr_(n); if (n->op.same_as(op) && fis_supported_op_(n)) { op_roots_.insert(n->args[0]); diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py index 6a7fa15ce567..e5b7f62dcd34 100644 --- a/tests/python/relay/test_pass_combine_parallel_dense.py +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -31,7 +31,7 @@ def run_opt_pass(expr, opt_pass): def test_combine_parallel_dense(): - """Simple testcase. Three can be combined, either because of mismatched shapes or units""" + """Simple testcase. One dense cannot be combined because of mismatched shapes or units""" def before(x, w1, w2, w3, w4, units): args = [x, w1, w2, w3, w4] y1 = relay.nn.dense(x, w1) @@ -147,7 +147,7 @@ def check(i, j, k, is_2d_bias): check(100, 200, 300, True) def test_combine_parallel_dense_biasadd_scale_reshape(): - """Testcase of combining dense + 1d biasadd""" + """Testcase of combining dense + 1d biasadd + multiply with non-fused reshape""" def before(x, w1, w2, b1, b2, scale1, scale2, newshape): args = [x, w1, w2, b1, b2, scale1, scale2] y1 = relay.nn.dense(x, w1) From 8e85c858cb9527f2f53ee3fe1a94e84fdbaf62b8 Mon Sep 17 00:00:00 2001 From: Jon Date: Fri, 30 Aug 2019 17:41:07 -0700 Subject: [PATCH 09/21] fix wording --- src/relay/pass/combine_parallel_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/combine_parallel_op.cc b/src/relay/pass/combine_parallel_op.cc index 73b95e9577cf..35e5bff6d63c 100644 --- a/src/relay/pass/combine_parallel_op.cc +++ b/src/relay/pass/combine_parallel_op.cc @@ -141,7 +141,7 @@ void ParallelOpCombiner::CombineBranches(const Group& branches) { }); size_t depth = it->size(); size_t i; - // starting from 1 to skip the dense + // starting from 1 to skip the op for (i = 1; i < depth; i++) { size_t parent_index; for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) { From 132dd10278b5519a8610a8386ceec7279e0eb73d Mon Sep 17 00:00:00 2001 From: Jon Date: Tue, 3 Sep 2019 10:56:55 -0700 Subject: [PATCH 10/21] Start to add combine_parallel_op_batch --- src/relay/pass/combine_parallel_conv2d.cc | 4 +- src/relay/pass/combine_parallel_op_batch.cc | 206 ++++++++++++++++++++ 2 files changed, 208 insertions(+), 2 deletions(-) create mode 100644 src/relay/pass/combine_parallel_op_batch.cc diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 736a297f6d22..ef38dcae24c3 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -61,7 +61,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { AttrsEqual eq; - static const Layout kOIHW("OIHW"); + const Layout kOIHW("OIHW"); const auto* attrs_a = a->attrs.as(); const auto* attrs_b = b->attrs.as(); CHECK(attrs_a); @@ -83,7 +83,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { } Call MakeCombinedOp(const Group& branches) { - static const Op& conv2d = Op::Get("nn.conv2d"); + const Op& conv2d = Op::Get("nn.conv2d"); Expr data = branches[0][0]->args[0]; Expr new_weight; IndexExpr new_channels; diff --git a/src/relay/pass/combine_parallel_op_batch.cc b/src/relay/pass/combine_parallel_op_batch.cc new file mode 100644 index 000000000000..6df78c36b666 --- /dev/null +++ b/src/relay/pass/combine_parallel_op_batch.cc @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file combine_parallel_op_batch.cc + * \brief Combine parallel ops into a single batch op. + * + * This pass replaces ops that share the same input node and same shape + * with a single op that takes in batched input. The inputs of the new + * batched op are the stack of the original inputs. Elementwise and + * broadcast ops following the original op are also stacked + * and fused if possible. For example: + * + * data + * / \ + * add (2,2) add (2,2) + * | | + * elemwise (2,2) elemwise (2,2) + * | | + * + * Would become: + * +* data +* | +* add (2,2,2) +* | +* elemwise (2,2,2) +* / \ + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./expr_subst.h" +#include "./pattern_util.h" +#include "./combine_parallel_op.h" + +namespace tvm { +namespace relay { + +class ParallelOpBatchCombiner : public ParallelOpCombiner { + public: + ParallelOpBatchCombiner(const std::string& op_name, uint64_t min_num_branches) + : ParallelOpBatchCombiner(op_name, op_name, min_num_branches) { + } + + ParallelOpBatchCombiner(const std::string& op_name, const std::string& batched_op_name, uint64_t min_num_branches) + : ParallelOpCombiner(op_name, min_num_branches), + batched_op_name_(batched_op_name) { + } + + protected: + bool IsSupportedOp(const CallNode* n) { + return true; + } + + bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { + if (!eq(a->args.size(), b->args.size())) { + return false; + } + + for (size_t i = 0; i < a->args.size(); i++) { + const auto* ta = a->args[i]->type_as(); + const auto* tb = b->args[i]->type_as(); + if (!eq(ta->shape.size(), tb->shape.size()) || !eq(ta->dtype, tb->dtype)) { + return false; + } + + for (size_t j = 0; j < ta->shape.size(); j++) { + if (!eq(ta->shape[j], tb->shape[j])) { + return false; + } + } + } + } + + Call MakeCombinedOp(const Group& branches) { + const Op& batch_op = Op::Get(batched_op_name_); + + Array stacked_args; + size_t num_args = branches[0][0]->args.size(); + for (size_t i = 0; i < num_args; i++) { + Array new_arg; + for (const auto& branch : branches) { + new_arg.push_back(branch[0]->args[i]); + } + + Expr new_arg_stack = MakeStack(TupleNode::make(new_arg), 0); + stacked_args.push_back(new_arg_stack); + } + + return CallNode::make(batch_op, stacked_args, Attrs(), {}); + } + + bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { + AttrsEqual eq; + auto ta = a->args[index]->type_as(); + auto tb = b->args[index]->type_as(); + + if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) + return false; + + for (size_t i = 0; i < ta->shape.size(); i++) { + if (!eq(ta->shape[i], tb->shape[i])) + return false; + } + return true; + } + + Call MakeCombinedCallFromFollowingOps(const Expr& data, + const Group& branches, + size_t depth, + size_t parent_index) { + Array new_args; + const CallNode* call = branches[0][depth]; + + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) { + new_args.push_back(data); + continue; + } + + Array tuple; + for (const auto& branch : branches) { + // if the shape of the arg is 1D, expand it to (1,j) so it can be properly broadcasted. + Expr arg = branch[depth]->args[i]; + const TensorTypeNode* arg_tensor = arg->type_as(); + if (arg_tensor->shape.size() == 1) { + Expr expanded_arg = MakeExpandDims(arg, 0, 1); + tuple.push_back(expanded_arg); + } else { + tuple.push_back(arg); + } + } + + auto stack = MakeStack(TupleNode::make(tuple), 0); + new_args.push_back(std::move(stack)); + } + + return CallNode::make(call->op, new_args, call->attrs, {}); + } + + void UpdateGroupOutput(const Expr& data, + const Group& branches, + size_t depth, + ExprSubstMap* subst_map) { + int index = 0; + auto split = MakeSplit(data, Integer(branches.size()), 0); + for (const auto& branch : branches) { + auto split_data = TupleGetItemNode::make(split, index++); + auto squeezed_data = MakeSqueeze(split_data, {0}); + subst_map->insert({GetRef(branch[depth]), squeezed_data}); + } + } + +private: + std::string batched_op_name_; +}; + +/*! \brief Combine parallel dense if number of branches >= min_num_branches */ +Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) { + return ParallelDenseCombiner(min_num_branches).Combine(expr); +} + +namespace transform { + +Pass CombineParallelDense(uint64_t min_num_branches) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(CombineParallelDense(f, min_num_branches)); + }; + return CreateFunctionPass(pass_func, 4, "CombineParallelDense", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.CombineParallelDense") +.set_body_typed(CombineParallelDense); + +} // namespace transform + +} // namespace relay +} // namespace tvm From c4908b7eecd0587ac17174bb1e9e4ac0c5844845 Mon Sep 17 00:00:00 2001 From: Jon Date: Wed, 4 Sep 2019 11:39:55 -0700 Subject: [PATCH 11/21] Resolve PR comments --- python/tvm/relay/transform.py | 12 +- src/relay/pass/combine_parallel_dense.cc | 88 +------- src/relay/pass/combine_parallel_op.h | 89 ++++++-- src/relay/pass/combine_parallel_op_batch.cc | 214 ++++++++++---------- src/relay/pass/combine_parallel_op_batch.h | 76 +++++++ 5 files changed, 275 insertions(+), 204 deletions(-) create mode 100644 src/relay/pass/combine_parallel_op_batch.h diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 4562d3488afa..936785427648 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -402,7 +402,17 @@ def CombineParallelConv2D(min_num_branches=3): def CombineParallelDense(min_num_branches=3): - """Combine multiple dense operators into one. + """Combine multiple dense operators into one. For example: + + data + / \ + dense (2,2) dense (2,2) + + Would become: + + data + | + batch_matmul (2,2,2) Parameters ---------- diff --git a/src/relay/pass/combine_parallel_dense.cc b/src/relay/pass/combine_parallel_dense.cc index f5929f5cca7b..0caf3879960c 100644 --- a/src/relay/pass/combine_parallel_dense.cc +++ b/src/relay/pass/combine_parallel_dense.cc @@ -42,24 +42,24 @@ #include #include "./expr_subst.h" #include "./pattern_util.h" -#include "./combine_parallel_op.h" +#include "./combine_parallel_op_batch.h" namespace tvm { namespace relay { -class ParallelDenseCombiner : public ParallelOpCombiner { +class ParallelDenseCombiner : public ParallelOpBatchCombiner { public: explicit ParallelDenseCombiner(uint64_t min_num_branches) - : ParallelOpCombiner("nn.dense", min_num_branches) { + : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) { } protected: - bool IsSupportedOp(const CallNode* n) { + virtual bool IsSupportedOp(const CallNode* n) { const auto* attrs = n->attrs.as(); return !attrs->units.defined(); } - bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { + virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { AttrsEqual eq; const auto* attrs_a = a->attrs.as(); const auto* attrs_b = b->attrs.as(); @@ -73,84 +73,6 @@ class ParallelDenseCombiner : public ParallelOpCombiner { eq(weight_a->shape[1], weight_b->shape[1]) && eq(attrs_a->units.defined(), attrs_b->units.defined()); } - - Call MakeCombinedOp(const Group& branches) { - static const Op& batch_matmul = Op::Get("nn.batch_matmul"); - Array datas; - Array weights; - for (const auto& branch : branches) { - auto dense = branch[0]; - auto data = dense->args[0]; - auto weight = dense->args[1]; - datas.push_back(data); - weights.push_back(weight); - } - - Expr new_data = MakeStack(TupleNode::make(datas), 0); - Expr new_weight = MakeStack(TupleNode::make(weights), 0); - return CallNode::make(batch_matmul, {new_data, new_weight}, Attrs(), {}); - } - - bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { - AttrsEqual eq; - auto ta = a->args[index]->type_as(); - auto tb = b->args[index]->type_as(); - - if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) - return false; - - for (size_t i = 0; i < ta->shape.size(); i++) { - if (!eq(ta->shape[i], tb->shape[i])) - return false; - } - return true; - } - - Call MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, - size_t parent_index) { - Array new_args; - const CallNode* call = branches[0][depth]; - - for (size_t i = 0; i < call->args.size(); i++) { - if (i == parent_index) { - new_args.push_back(data); - continue; - } - - Array tuple; - for (const auto& branch : branches) { - // if the shape of the arg is 1D, expand it to (1,j) so it can be properly broadcasted. - Expr arg = branch[depth]->args[i]; - const TensorTypeNode* arg_tensor = arg->type_as(); - if (arg_tensor->shape.size() == 1) { - Expr expanded_arg = MakeExpandDims(arg, 0, 1); - tuple.push_back(expanded_arg); - } else { - tuple.push_back(arg); - } - } - - auto stack = MakeStack(TupleNode::make(tuple), 0); - new_args.push_back(std::move(stack)); - } - - return CallNode::make(call->op, new_args, call->attrs, {}); - } - - void UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, - ExprSubstMap* subst_map) { - int index = 0; - auto split = MakeSplit(data, Integer(branches.size()), 0); - for (const auto& branch : branches) { - auto split_data = TupleGetItemNode::make(split, index++); - auto squeezed_data = MakeSqueeze(split_data, {0}); - subst_map->insert({GetRef(branch[depth]), squeezed_data}); - } - } }; /*! \brief Combine parallel dense if number of branches >= min_num_branches */ diff --git a/src/relay/pass/combine_parallel_op.h b/src/relay/pass/combine_parallel_op.h index c59dd2c3267e..65dc5d2a8371 100644 --- a/src/relay/pass/combine_parallel_op.h +++ b/src/relay/pass/combine_parallel_op.h @@ -50,9 +50,10 @@ using FAreCompatibleOps = std::function; /* - Class to find parallel branches starting with op as shown below and then - group branches by kernel shape and attributes of op. - Op can be followed by zero or more elemwise or broadcast ops. + Class to find parallel branches starting with op that are + grouped if they are able to be combined. + Op can be followed by zero or more elemwise or broadcast ops, + which are included in the group. Intermediate nodes have exactly one successor. It is possible that branches meet at a point, which should be handled in ParallelOpCombiner. @@ -65,10 +66,22 @@ using ExprSubstMap = std::unordered_map; */ class BranchGroupFinder : private ExprVisitor { public: + /* + @brief Constructor. + @param op_name name of op to start each group + @param fis_supported_op function that returns true if op + is supported for combining + @param fare_compatible_ops function that returns true if + two ops are compatible for combining + */ BranchGroupFinder(const std::string& op_name, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops); + /* + @brief Finds all groups that can be combined. + @return Vector of groups which can be combined. + */ std::vector Find(const Expr& expr); private: @@ -78,46 +91,90 @@ class BranchGroupFinder : private ExprVisitor { std::unordered_set op_roots_; std::unordered_map, NodeHash, NodeEqual> children_map_; - // Create a branch starting from op. Branch CreateBranch(const CallNode* op); void VisitExpr_(const CallNode* n) final; }; /* - Abstract class to find and combine parallel ops and the element-wise ops that follow. + Abstract class to find and combine parallel ops and the elementwise ops that follow. */ class ParallelOpCombiner { public: + /* + @brief Constructor. + @param op_name name of op to combine + @param min_num_branches min number of parallel branches beginning with op + to start combining + */ explicit ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches); + /* + @brief Combines ops and following elementwise or broadcast ops + @param expr function to modify + @return new function with combined ops + */ Expr Combine(const Expr& expr); protected: - // Returns true if the op represented by CallNode n is supported to be the - // root of a branch to be combined. Otherwise, returns false. + /* + @brief Checks if node is supported to be combined + @param n node in question + @return True if the op represented by n is supported to be the root of a branch + to be combined. False otherwise. + */ virtual bool IsSupportedOp(const CallNode* n) = 0; - // Returns true if ops represented by CallNodes a and b can be combined. - // Otherwise, returns false. + /* + @brief Checks if two ops can be combined + @param a node a + @param b node b + @return True if a and b can be combined. False otherwise. + */ virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) = 0; - // Create Call that consists of the combined ops. This usually involves concatenating - // or stacking inputs, then creating a new call. + /* + @brief Makes combined op from parallel ops in branches. This usually involves + concatenating or stacking inputs, then creating a new call. + @param branches branches that are to be combined + @return new call with branches combined. + */ virtual Call MakeCombinedOp(const Group& branches) = 0; - // Returns true if arguments of a and b at index index can be combined. + /* + @brief Checks if argument of op following combined ops are able to be combined + @param a node a + @param b node b + @param index index of argument in question + @return True if argument of a and b and index can be combined + */ virtual bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) = 0; - // Create combined call of ops that follow initial combined op in depth-th level. - // This usually involves concatenating or stacking inputs, then creating a new call. - // Only called if IsArgCompatible returns true for each arg. + /* + @brief Create combined call from ops that follow the initial combined op at the depth-th level. + This usually involves concatenating or stacking inputs, then creating a new call. + Only called if IsArgCompatbile returns true for each arg. + @param data combined op + @param branches branches of parallel ops to be combined + @param depth depth at which to combine ops + @param parent_index index of arg that corresponds to original input that was shared among + all combined ops + @return new combined call + */ virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth, size_t parent_index) = 0; - // Replace output of each branch with slices of the combined output. + /* + @brief Updates map of expr to substitute with combined expr. This usually involves + slicing or splitting data. + @param data combined op + @param branches branches of parallel ops to be combined + @param depth depth at which to substitute + @param subst_map map of Expr to replace with Expr to replace it with + Replace output of each branch with slices of the combined output + */ virtual void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, diff --git a/src/relay/pass/combine_parallel_op_batch.cc b/src/relay/pass/combine_parallel_op_batch.cc index 6df78c36b666..e0fe415daca0 100644 --- a/src/relay/pass/combine_parallel_op_batch.cc +++ b/src/relay/pass/combine_parallel_op_batch.cc @@ -38,12 +38,10 @@ * * Would become: * -* data -* | -* add (2,2,2) -* | -* elemwise (2,2,2) -* / \ + * data + * | + * add+elemwise (2,2,2) + * / \ * */ @@ -58,147 +56,155 @@ #include "./expr_subst.h" #include "./pattern_util.h" #include "./combine_parallel_op.h" +#include "./combine_parallel_op_batch.h" namespace tvm { namespace relay { -class ParallelOpBatchCombiner : public ParallelOpCombiner { - public: - ParallelOpBatchCombiner(const std::string& op_name, uint64_t min_num_branches) - : ParallelOpBatchCombiner(op_name, op_name, min_num_branches) { - } +ParallelOpBatchCombiner::ParallelOpBatchCombiner(const std::string& op_name, + const std::string& batch_op_name, + uint64_t min_num_branches) + : ParallelOpCombiner(op_name, min_num_branches), + batch_op_name_(batch_op_name) { +} - ParallelOpBatchCombiner(const std::string& op_name, const std::string& batched_op_name, uint64_t min_num_branches) - : ParallelOpCombiner(op_name, min_num_branches), - batched_op_name_(batched_op_name) { - } +bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) { + return true; +} - protected: - bool IsSupportedOp(const CallNode* n) { - return true; +bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode* b) { + if (a->args.size() != b->args.size()) { + return false; } - bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { - if (!eq(a->args.size(), b->args.size())) { + AttrsEqual eq; + for (size_t i = 0; i < a->args.size(); i++) { + auto ta = a->args[i]->type_as(); + auto tb = b->args[i]->type_as(); + if (ta->shape.size() != tb->shape.size() || !eq(ta->dtype, tb->dtype)) { return false; } - for (size_t i = 0; i < a->args.size(); i++) { - const auto* ta = a->args[i]->type_as(); - const auto* tb = b->args[i]->type_as(); - if (!eq(ta->shape.size(), tb->shape.size()) || !eq(ta->dtype, tb->dtype)) { + for (size_t j = 0; j < ta->shape.size(); j++) { + if (!eq(ta->shape[j], tb->shape[j])) { return false; } - - for (size_t j = 0; j < ta->shape.size(); j++) { - if (!eq(ta->shape[j], tb->shape[j])) { - return false; - } - } } } - Call MakeCombinedOp(const Group& branches) { - const Op& batch_op = Op::Get(batched_op_name_); - - Array stacked_args; - size_t num_args = branches[0][0]->args.size(); - for (size_t i = 0; i < num_args; i++) { - Array new_arg; - for (const auto& branch : branches) { - new_arg.push_back(branch[0]->args[i]); - } + return true; +} - Expr new_arg_stack = MakeStack(TupleNode::make(new_arg), 0); - stacked_args.push_back(new_arg_stack); +Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) { + const Op& batch_op = Op::Get(batch_op_name_); + + Array new_args; + size_t num_args = branches[0][0]->args.size(); + for (size_t i = 0; i < num_args; i++) { + Array arg_from_all_branches; + for (const auto& branch : branches) { + arg_from_all_branches.push_back(branch[0]->args[i]); } - return CallNode::make(batch_op, stacked_args, Attrs(), {}); + new_args.push_back(MakeStack(TupleNode::make(arg_from_all_branches), 0)); } - bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { - AttrsEqual eq; - auto ta = a->args[index]->type_as(); - auto tb = b->args[index]->type_as(); + return CallNode::make(batch_op, new_args, Attrs(), {}); +} - if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) - return false; +bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { + AttrsEqual eq; + auto ta = a->args[index]->type_as(); + auto tb = b->args[index]->type_as(); - for (size_t i = 0; i < ta->shape.size(); i++) { - if (!eq(ta->shape[i], tb->shape[i])) - return false; - } - return true; + if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) + return false; + + for (size_t i = 0; i < ta->shape.size(); i++) { + if (!eq(ta->shape[i], tb->shape[i])) + return false; } + return true; +} - Call MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, - size_t parent_index) { - Array new_args; - const CallNode* call = branches[0][depth]; - - for (size_t i = 0; i < call->args.size(); i++) { - if (i == parent_index) { - new_args.push_back(data); - continue; - } +Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, + const Group& branches, + size_t depth, + size_t parent_index) { + Array new_args; + const CallNode* call = branches[0][depth]; + const Op& bias_add = Op::Get("nn.bias_add"); + + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) { + new_args.push_back(data); + continue; + } - Array tuple; - for (const auto& branch : branches) { - // if the shape of the arg is 1D, expand it to (1,j) so it can be properly broadcasted. - Expr arg = branch[depth]->args[i]; - const TensorTypeNode* arg_tensor = arg->type_as(); - if (arg_tensor->shape.size() == 1) { - Expr expanded_arg = MakeExpandDims(arg, 0, 1); - tuple.push_back(expanded_arg); - } else { - tuple.push_back(arg); - } + Array tuple; + for (const auto& branch : branches) { + Expr arg = branch[depth]->args[i]; + const TensorTypeNode* arg_tensor = arg->type_as(); + + // special case for bias_add: 1D data needs to be expanded to (1,size) + // for proper broadcasting. + // + // note that this can't be applied generally, as elementwise ops + // such as full have a valid 1D input. + if (call->op.same_as(bias_add)) { + Expr expanded_arg = MakeExpandDims(arg, 0, 1); + tuple.push_back(expanded_arg); + } else { + tuple.push_back(arg); } - - auto stack = MakeStack(TupleNode::make(tuple), 0); - new_args.push_back(std::move(stack)); } - return CallNode::make(call->op, new_args, call->attrs, {}); + auto stack = MakeStack(TupleNode::make(tuple), 0); + new_args.push_back(std::move(stack)); } - void UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, - ExprSubstMap* subst_map) { - int index = 0; - auto split = MakeSplit(data, Integer(branches.size()), 0); - for (const auto& branch : branches) { - auto split_data = TupleGetItemNode::make(split, index++); - auto squeezed_data = MakeSqueeze(split_data, {0}); - subst_map->insert({GetRef(branch[depth]), squeezed_data}); - } - } + return CallNode::make(call->op, new_args, call->attrs, {}); +} -private: - std::string batched_op_name_; -}; +void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, + const Group& branches, + size_t depth, + ExprSubstMap* subst_map) { + int index = 0; + auto split = MakeSplit(data, Integer(branches.size()), 0); + for (const auto& branch : branches) { + auto split_data = TupleGetItemNode::make(split, index++); + auto squeezed_data = MakeSqueeze(split_data, {0}); + subst_map->insert({GetRef(branch[depth]), squeezed_data}); + } +} -/*! \brief Combine parallel dense if number of branches >= min_num_branches */ -Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) { - return ParallelDenseCombiner(min_num_branches).Combine(expr); +/*! \brief Combine parallel op into batched op if number of branches >= min_num_branches */ +Expr CombineParallelOpBatch(const Expr& expr, + const std::string& op_name, + const std::string& batch_op_name, + uint64_t min_num_branches) { + return ParallelOpBatchCombiner(op_name, batch_op_name, min_num_branches).Combine(expr); } namespace transform { -Pass CombineParallelDense(uint64_t min_num_branches) { +Pass CombineParallelOpBatch(const std::string& op_name, + const std::string& batch_op_name, + uint64_t min_num_branches) { runtime::TypedPackedFunc pass_func = [=](Function f, Module m, PassContext pc) { - return Downcast(CombineParallelDense(f, min_num_branches)); + return Downcast(CombineParallelOpBatch(f, + op_name, + batch_op_name, + min_num_branches)); }; - return CreateFunctionPass(pass_func, 4, "CombineParallelDense", + return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.CombineParallelDense") -.set_body_typed(CombineParallelDense); +TVM_REGISTER_API("relay._transform.CombineParallelOpBatch") +.set_body_typed(CombineParallelOpBatch); } // namespace transform diff --git a/src/relay/pass/combine_parallel_op_batch.h b/src/relay/pass/combine_parallel_op_batch.h new file mode 100644 index 000000000000..de9aaed76bab --- /dev/null +++ b/src/relay/pass/combine_parallel_op_batch.h @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file combine_parallel_op_batch.cc + * \brief Combine parallel ops into a single batch op. + */ +#ifndef TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_ +#define TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "./expr_subst.h" +#include "./pattern_util.h" +#include "./combine_parallel_op.h" + +namespace tvm { +namespace relay { + +class ParallelOpBatchCombiner : public ParallelOpCombiner { + public: + ParallelOpBatchCombiner(const std::string& op_name, + const std::string& batch_op_name, + uint64_t min_num_branches); + + protected: + virtual bool IsSupportedOp(const CallNode* n); + + virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b); + + Call MakeCombinedOp(const Group& branches) final; + + bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) final; + + Call MakeCombinedCallFromFollowingOps(const Expr& data, + const Group& branches, + size_t depth, + size_t parent_index) final; + + void UpdateGroupOutput(const Expr& data, + const Group& branches, + size_t depth, + ExprSubstMap* subst_map) final; + + private: + std::string batch_op_name_; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_ From ea3738a46edb59c80d25ed3bb63a1b40589f461c Mon Sep 17 00:00:00 2001 From: Jon Date: Wed, 4 Sep 2019 15:58:52 -0700 Subject: [PATCH 12/21] Resolve PR comments --- src/relay/pass/combine_parallel_conv2d.cc | 18 +++++++++--------- src/relay/pass/combine_parallel_op_batch.cc | 1 - 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index ef38dcae24c3..b461d83391f1 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -106,8 +106,8 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { const std::string& layout = new_attrs->out_layout == "" ? new_attrs->data_layout : new_attrs->out_layout; - channel_pos = layout.find('C'); - CHECK_NE(channel_pos, std::string::npos); + channel_pos_ = layout.find('C'); + CHECK_NE(channel_pos_, std::string::npos); return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {}); } @@ -123,12 +123,12 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { return false; // Position of the 'C' dimension in the argument - size_t arg_channel_pos = channel_pos - toutput_a->shape.size() + ta->shape.size(); + size_t arg_channel_pos = channel_pos_ - toutput_a->shape.size() + ta->shape.size(); // Channel super-dimension shoule be present and not broadcasted - if ((arg_channel_pos > channel_pos) || // size_t overflow - !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos]) || - !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos])) + if ((arg_channel_pos > channel_pos_) || // size_t overflow + !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos_]) || + !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos_])) return false; for (size_t i = 0; i < ta->shape.size(); i++) { @@ -154,7 +154,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { } size_t arg_ndim = call->args[i]->type_as()->shape.size(); - size_t arg_channel_pos = channel_pos - ndim + arg_ndim; + size_t arg_channel_pos = channel_pos_ - ndim + arg_ndim; Array tuple; for (const auto& branch : branches) { tuple.push_back(branch[depth]->args[i]); @@ -177,7 +177,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { int64_t channels = GetConv2DSuperChannelsDim(conv2d); Array begin; Array end; - for (size_t i = 0; i < channel_pos; i++) { + for (size_t i = 0; i < channel_pos_; i++) { begin.push_back(0); end.push_back(NullValue()); } @@ -190,7 +190,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { } private: - size_t channel_pos; + size_t channel_pos_; std::tuple TransformWeight(const Group& branches) { int64_t num_filters = 0; // number of filters of the transformed weight diff --git a/src/relay/pass/combine_parallel_op_batch.cc b/src/relay/pass/combine_parallel_op_batch.cc index e0fe415daca0..e8e23bf1d407 100644 --- a/src/relay/pass/combine_parallel_op_batch.cc +++ b/src/relay/pass/combine_parallel_op_batch.cc @@ -144,7 +144,6 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, Array tuple; for (const auto& branch : branches) { Expr arg = branch[depth]->args[i]; - const TensorTypeNode* arg_tensor = arg->type_as(); // special case for bias_add: 1D data needs to be expanded to (1,size) // for proper broadcasting. From 32998b437b67fddd165bf20009ea382bcd2be13d Mon Sep 17 00:00:00 2001 From: Jon Date: Fri, 6 Sep 2019 12:31:38 -0700 Subject: [PATCH 13/21] dummy change to retrigger CI --- src/relay/pass/combine_parallel_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/combine_parallel_op.h b/src/relay/pass/combine_parallel_op.h index 65dc5d2a8371..b67939fe9edd 100644 --- a/src/relay/pass/combine_parallel_op.h +++ b/src/relay/pass/combine_parallel_op.h @@ -67,7 +67,7 @@ using ExprSubstMap = std::unordered_map; class BranchGroupFinder : private ExprVisitor { public: /* - @brief Constructor. + @brief Constructor @param op_name name of op to start each group @param fis_supported_op function that returns true if op is supported for combining From befc6a01febe4408244143d00262410c136b18c2 Mon Sep 17 00:00:00 2001 From: Jon Date: Mon, 9 Sep 2019 09:24:04 -0700 Subject: [PATCH 14/21] Change special case from bias_add to add --- src/relay/pass/combine_parallel_op_batch.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/relay/pass/combine_parallel_op_batch.cc b/src/relay/pass/combine_parallel_op_batch.cc index e8e23bf1d407..bb2603fdba15 100644 --- a/src/relay/pass/combine_parallel_op_batch.cc +++ b/src/relay/pass/combine_parallel_op_batch.cc @@ -133,7 +133,7 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, size_t parent_index) { Array new_args; const CallNode* call = branches[0][depth]; - const Op& bias_add = Op::Get("nn.bias_add"); + const Op& add = Op::Get("add"); for (size_t i = 0; i < call->args.size(); i++) { if (i == parent_index) { @@ -144,13 +144,14 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, Array tuple; for (const auto& branch : branches) { Expr arg = branch[depth]->args[i]; + const TensorTypeNode* arg_tensor = arg->type_as(); - // special case for bias_add: 1D data needs to be expanded to (1,size) + // special case for add: 1D data needs to be expanded to (1,size) // for proper broadcasting. // // note that this can't be applied generally, as elementwise ops // such as full have a valid 1D input. - if (call->op.same_as(bias_add)) { + if (call->op.same_as(add) && arg_tensor->shape.size() == 1) { Expr expanded_arg = MakeExpandDims(arg, 0, 1); tuple.push_back(expanded_arg); } else { From 36c43dee4a6fce4b93c1f4a1a77647351c7130f4 Mon Sep 17 00:00:00 2001 From: Jon Date: Mon, 9 Sep 2019 09:33:21 -0700 Subject: [PATCH 15/21] Revert special case change --- src/relay/pass/combine_parallel_op_batch.cc | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/relay/pass/combine_parallel_op_batch.cc b/src/relay/pass/combine_parallel_op_batch.cc index bb2603fdba15..235b230dfb31 100644 --- a/src/relay/pass/combine_parallel_op_batch.cc +++ b/src/relay/pass/combine_parallel_op_batch.cc @@ -133,7 +133,6 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, size_t parent_index) { Array new_args; const CallNode* call = branches[0][depth]; - const Op& add = Op::Get("add"); for (size_t i = 0; i < call->args.size(); i++) { if (i == parent_index) { @@ -143,15 +142,11 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, Array tuple; for (const auto& branch : branches) { + // if the shape of the arg is of shape (j,), + // expand it to (1,j) so it can be properly broadcasted. Expr arg = branch[depth]->args[i]; const TensorTypeNode* arg_tensor = arg->type_as(); - - // special case for add: 1D data needs to be expanded to (1,size) - // for proper broadcasting. - // - // note that this can't be applied generally, as elementwise ops - // such as full have a valid 1D input. - if (call->op.same_as(add) && arg_tensor->shape.size() == 1) { + if (arg_tensor->shape.size() == 1) { Expr expanded_arg = MakeExpandDims(arg, 0, 1); tuple.push_back(expanded_arg); } else { From 797e8326396c6e2851dab61e3167cf842fe34f77 Mon Sep 17 00:00:00 2001 From: Jon Date: Tue, 10 Sep 2019 09:53:18 -0700 Subject: [PATCH 16/21] Ignore units check --- src/relay/pass/combine_parallel_dense.cc | 8 +--- .../relay/test_pass_combine_parallel_dense.py | 38 ++++++------------- 2 files changed, 13 insertions(+), 33 deletions(-) diff --git a/src/relay/pass/combine_parallel_dense.cc b/src/relay/pass/combine_parallel_dense.cc index 0caf3879960c..7b00fef9bd36 100644 --- a/src/relay/pass/combine_parallel_dense.cc +++ b/src/relay/pass/combine_parallel_dense.cc @@ -54,11 +54,6 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner { } protected: - virtual bool IsSupportedOp(const CallNode* n) { - const auto* attrs = n->attrs.as(); - return !attrs->units.defined(); - } - virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { AttrsEqual eq; const auto* attrs_a = a->attrs.as(); @@ -70,8 +65,7 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner { return eq(attrs_a->out_dtype, attrs_b->out_dtype) && eq(weight_a->shape[0], weight_b->shape[0]) && - eq(weight_a->shape[1], weight_b->shape[1]) && - eq(attrs_a->units.defined(), attrs_b->units.defined()); + eq(weight_a->shape[1], weight_b->shape[1]); } }; diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py index e5b7f62dcd34..8c8242786196 100644 --- a/tests/python/relay/test_pass_combine_parallel_dense.py +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -31,23 +31,20 @@ def run_opt_pass(expr, opt_pass): def test_combine_parallel_dense(): - """Simple testcase. One dense cannot be combined because of mismatched shapes or units""" - def before(x, w1, w2, w3, w4, units): + """Simple testcase. One dense cannot be combined because of mismatched shapes""" + def before(x, w1, w2, w3, w4): args = [x, w1, w2, w3, w4] y1 = relay.nn.dense(x, w1) y2 = relay.nn.dense(x, w2) # y3 cannot be combined - if units == -1: - y3 = relay.nn.dense(x, w3) - else: - y3 = relay.nn.dense(x, w3, units=units) + y3 = relay.nn.dense(x, w3) y4 = relay.nn.dense(x, w4) y = relay.Tuple((y1, y2, y3, y4)) return relay.Function(args, y) - def expected(x, w1, w2, w3, w4, units): + def expected(x, w1, w2, w3, w4): # use a fixed order of args so alpha equal check can pass args = [x, w1, w2, w3, w4] x_stacked = relay.stack((x, x, x), axis=0) @@ -58,39 +55,28 @@ def expected(x, w1, w2, w3, w4, units): y2 = relay.squeeze(y2, [0]) y4 = relay.squeeze(y4, [0]) - if units == -1: - y3 = relay.nn.dense(x, w3) - else: - y3 = relay.nn.dense(x, w3, units=units) + # y3 cannot be combined + y3 = relay.nn.dense(x, w3) y = relay.Tuple((y1, y2, y3, y4)) return relay.Function(args, y) - def check(i, j, k, use_units): + def check(i, j, k): x = relay.var("x", shape=(i, k)) w1 = relay.var("w1", shape=(j, k)) w2 = relay.var("w2", shape=(j, k)) - - if use_units: - units = j - w3 = relay.var("w3", shape=(j, k)) - else: - units = -1 - w3 = relay.var("w3", shape=(j + 1, k)) - + w3 = relay.var("w3", shape=(j + 1, k)) w4 = relay.var("w4", shape=(j, k)) - y_before = before(x, w1, w2, w3, w4, units) + y_before = before(x, w1, w2, w3, w4) y = run_opt_pass(y_before, transform.CombineParallelDense(min_num_branches=2)) - y_expected = expected(x, w1, w2, w3, w4, units) + y_expected = expected(x, w1, w2, w3, w4) y_expected = run_opt_pass(y_expected, transform.InferType()) assert relay.analysis.alpha_equal(y, y_expected) - check(3, 5, 4, False) - check(100, 200, 300, False) - check(3, 5, 4, True) - check(100, 200, 300, True) + check(3, 5, 4) + check(100, 200, 300) def test_combine_parallel_dense_biasadd(): From b2bfad95783b7b25497ce129966b520fe4b851d5 Mon Sep 17 00:00:00 2001 From: Jon Date: Tue, 10 Sep 2019 12:59:22 -0700 Subject: [PATCH 17/21] dummy change to retrigger CI --- tests/python/relay/test_pass_combine_parallel_dense.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py index 8c8242786196..f04dd3606aec 100644 --- a/tests/python/relay/test_pass_combine_parallel_dense.py +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -31,7 +31,7 @@ def run_opt_pass(expr, opt_pass): def test_combine_parallel_dense(): - """Simple testcase. One dense cannot be combined because of mismatched shapes""" + """Simple testcase. One dense cannot be combined because of shape mismatch""" def before(x, w1, w2, w3, w4): args = [x, w1, w2, w3, w4] y1 = relay.nn.dense(x, w1) From fe61ed934a165dcfac73bbfca3f6a99a6fecb1e3 Mon Sep 17 00:00:00 2001 From: Jon Date: Tue, 10 Sep 2019 18:19:26 -0700 Subject: [PATCH 18/21] dummy change to re-trigger CI --- tests/python/relay/test_pass_combine_parallel_dense.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py index f04dd3606aec..070ab8658b88 100644 --- a/tests/python/relay/test_pass_combine_parallel_dense.py +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -31,7 +31,7 @@ def run_opt_pass(expr, opt_pass): def test_combine_parallel_dense(): - """Simple testcase. One dense cannot be combined because of shape mismatch""" + """Simple testcase. One dense cannot be combined due to shape mismatch""" def before(x, w1, w2, w3, w4): args = [x, w1, w2, w3, w4] y1 = relay.nn.dense(x, w1) From 8405517fd5fecd52ff346a883e57e3546357917c Mon Sep 17 00:00:00 2001 From: Jon Date: Thu, 12 Sep 2019 15:20:22 -0700 Subject: [PATCH 19/21] Improve docs --- python/tvm/relay/transform.py | 10 +- src/relay/pass/combine_parallel_op.h | 122 ++++++++++++++------- src/relay/pass/combine_parallel_op_batch.h | 73 ++++++++++++ 3 files changed, 161 insertions(+), 44 deletions(-) diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 936785427648..58bf17efd387 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -404,15 +404,17 @@ def CombineParallelConv2D(min_num_branches=3): def CombineParallelDense(min_num_branches=3): """Combine multiple dense operators into one. For example: - data - / \ - dense (2,2) dense (2,2) + data + / \ + dense (2,2) dense (2,2) + | | + elemwise/bcast (2,2) elemwise/bcast (2,2) Would become: data | - batch_matmul (2,2,2) + batch_matmul+elemwise/bcast (2,2,2) Parameters ---------- diff --git a/src/relay/pass/combine_parallel_op.h b/src/relay/pass/combine_parallel_op.h index b67939fe9edd..9eeb050a2d44 100644 --- a/src/relay/pass/combine_parallel_op.h +++ b/src/relay/pass/combine_parallel_op.h @@ -51,7 +51,8 @@ using ExprSubstMap = std::unordered_map; /* Class to find parallel branches starting with op that are - grouped if they are able to be combined. + grouped if they are able to be combined. They are eligible to + be combined if they have the same input data. Op can be followed by zero or more elemwise or broadcast ops, which are included in the group. Intermediate nodes have exactly one successor. It is possible that branches meet at a point, @@ -67,11 +68,11 @@ using ExprSubstMap = std::unordered_map; class BranchGroupFinder : private ExprVisitor { public: /* - @brief Constructor - @param op_name name of op to start each group - @param fis_supported_op function that returns true if op + \brief Constructor + \param op_name name of op to start each group + \param fis_supported_op function that returns true if op is supported for combining - @param fare_compatible_ops function that returns true if + \param fare_compatible_ops function that returns true if two ops are compatible for combining */ BranchGroupFinder(const std::string& op_name, @@ -79,20 +80,43 @@ class BranchGroupFinder : private ExprVisitor { FAreCompatibleOps fare_compatible_ops); /* - @brief Finds all groups that can be combined. - @return Vector of groups which can be combined. + \brief Finds all groups that can be combined. + \return Vector of groups which can be combined. */ std::vector Find(const Expr& expr); private: + /* name of op to find parallel branches for */ std::string op_name_; + + /* function to return true if op is eligible to be combined, + false otherwise + */ FIsSupportedOp fis_supported_op_; + + /* function to return true if two parallel ops are eligible + to be combined, false otherwise + */ FAreCompatibleOps fare_compatible_ops_; + + /* ops that are on the first (logically, leftmost) branch + of parallel ops and are eligible to be combined + */ std::unordered_set op_roots_; + + /* map of Expr to CallNodes that follow it */ std::unordered_map, NodeHash, NodeEqual> children_map_; + /* + \brief Creates new branch from op and its children that have + elementwise or broadcast patterns + \return New branch + */ Branch CreateBranch(const CallNode* op); + /* + \brief Expression visitor function + */ void VisitExpr_(const CallNode* n) final; }; @@ -102,64 +126,64 @@ class BranchGroupFinder : private ExprVisitor { class ParallelOpCombiner { public: /* - @brief Constructor. - @param op_name name of op to combine - @param min_num_branches min number of parallel branches beginning with op + \brief Constructor. + \param op_name name of op to combine + \param min_num_branches min number of parallel branches beginning with op to start combining */ explicit ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches); /* - @brief Combines ops and following elementwise or broadcast ops - @param expr function to modify - @return new function with combined ops + \brief Combines ops and following elementwise or broadcast ops + \param expr function to modify + \return new function with combined ops */ Expr Combine(const Expr& expr); protected: /* - @brief Checks if node is supported to be combined - @param n node in question - @return True if the op represented by n is supported to be the root of a branch + \brief Checks if node is supported to be combined + \param n node in question + \return True if the op represented by n is supported to be the root of a branch to be combined. False otherwise. */ virtual bool IsSupportedOp(const CallNode* n) = 0; /* - @brief Checks if two ops can be combined - @param a node a - @param b node b - @return True if a and b can be combined. False otherwise. + \brief Checks if two ops can be combined + \param a node a + \param b node b + \return True if a and b can be combined. False otherwise. */ virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) = 0; /* - @brief Makes combined op from parallel ops in branches. This usually involves + \brief Makes combined op from parallel ops in branches. This usually involves concatenating or stacking inputs, then creating a new call. - @param branches branches that are to be combined - @return new call with branches combined. + \param branches branches that are to be combined + \return new call with branches combined. */ virtual Call MakeCombinedOp(const Group& branches) = 0; /* - @brief Checks if argument of op following combined ops are able to be combined - @param a node a - @param b node b - @param index index of argument in question - @return True if argument of a and b and index can be combined + \brief Checks if argument of op following combined ops are able to be combined + \param a node a + \param b node b + \param index index of argument in question + \return True if argument of a and b and index can be combined */ virtual bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) = 0; /* - @brief Create combined call from ops that follow the initial combined op at the depth-th level. + \brief Create combined call from ops that follow the initial combined op at the depth-th level. This usually involves concatenating or stacking inputs, then creating a new call. Only called if IsArgCompatbile returns true for each arg. - @param data combined op - @param branches branches of parallel ops to be combined - @param depth depth at which to combine ops - @param parent_index index of arg that corresponds to original input that was shared among + \param data combined op + \param branches branches of parallel ops to be combined + \param depth depth at which to combine ops + \param parent_index index of arg that corresponds to original input that was shared among all combined ops - @return new combined call + \return new combined call */ virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, @@ -167,13 +191,12 @@ class ParallelOpCombiner { size_t parent_index) = 0; /* - @brief Updates map of expr to substitute with combined expr. This usually involves + \brief Updates map of expr to substitute with combined expr. This usually involves slicing or splitting data. - @param data combined op - @param branches branches of parallel ops to be combined - @param depth depth at which to substitute - @param subst_map map of Expr to replace with Expr to replace it with - Replace output of each branch with slices of the combined output + \param data combined op + \param branches branches of parallel ops to be combined + \param depth depth at which to substitute + \param subst_map map of Expr to replace with Expr to replace it with */ virtual void UpdateGroupOutput(const Expr& data, const Group& branches, @@ -181,12 +204,31 @@ class ParallelOpCombiner { ExprSubstMap* subst_map) = 0; private: + /* name of op to be combined */ std::string op_name_; + + /* minimum number of parallel branches to combine */ uint64_t min_num_branches_; + + /* map of Expr to Expr to substitute it with after running pass */ ExprSubstMap subst_map_; + /* + \brief Combine parallel branches and updates subst_map_ with Exprs + to be substituted + \param branches branches to be combined + */ void CombineBranches(const Group& branches); + /* + \brief Combine parallel branches and updates subst_map_ with Exprs + to be substituted + \param branches parallel branches to potentially be combined + \param depth depth at which to look at op + \param parent_index index of arg that corresponds to original input that was shared among + all combined ops + \return true if parallel ops at depth can be combined, false otherwise + */ bool CheckLevel(const Group& branches, size_t depth, size_t parent_index); }; diff --git a/src/relay/pass/combine_parallel_op_batch.h b/src/relay/pass/combine_parallel_op_batch.h index de9aaed76bab..56ba1ced23a4 100644 --- a/src/relay/pass/combine_parallel_op_batch.h +++ b/src/relay/pass/combine_parallel_op_batch.h @@ -41,32 +41,105 @@ namespace tvm { namespace relay { +/* + Class to find and combine parallel ops and following element-wise + and broadcast ops into a single batch op. Ops can be combined + if they have the same input data. Batch op is formed by + stacking inputs. Final results are retrieved by splitting output. + For example: + + data + / \ + dense (2,2) dense (2,2) + | | + elemwise/bcast (2,2) elemwise/bcast (2,2) + + Would become: + + data + | + batch_matmul+elemwise/bcast (2,2,2) +*/ class ParallelOpBatchCombiner : public ParallelOpCombiner { public: + /* + \brief Constructor. + \param op_name name of op to combine + \param batch_op_name name of op that combined branches will be joined into + \param min_num_branches min number of parallel branches beginning with op + to start combining + */ ParallelOpBatchCombiner(const std::string& op_name, const std::string& batch_op_name, uint64_t min_num_branches); protected: + /* + \brief Checks if node is supported to be combined + \param n node in question + \return True by default + */ virtual bool IsSupportedOp(const CallNode* n); + /* + \brief Checks if two ops can be combined + \param a node a + \param b node b + \return True if shapes and dtypes of all args of a and b are the same + */ virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b); + /* + \brief Makes combined op from parallel ops in branches. This usually involves + concatenating or stacking inputs, then creating a new call. + \param branches branches that are to be combined + \return new call with branches combined as batch op by stacking args + */ Call MakeCombinedOp(const Group& branches) final; + /* + \brief Checks if argument of op following combined ops are able to be combined + \param a node a + \param b node b + \param index index of argument in question + \return True if shapes and dtypes of args[index] a and b are the same + */ bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) final; + /* + \brief Create combined call from ops that follow the initial combined op at the depth-th level. + This usually involves concatenating or stacking inputs, then creating a new call. + Only called if IsArgCompatbile returns true for each arg. + \param data combined op + \param branches branches of parallel ops to be combined + \param depth depth at which to combine ops + \param parent_index index of arg that corresponds to original input that was shared among + all combined ops + \return new combined call as batch op by stacking args + */ Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth, size_t parent_index) final; + /* + \brief Updates map of expr to substitute with combined expr. This usually involves + slicing or splitting data. + \param data combined op + \param branches branches of parallel ops to be combined + \param depth depth at which to substitute + \param subst_map map of Expr to replace with Expr to replace it with + */ void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) final; private: + /* name of op to replace combined ops with. for example, + for combining parallel dense, this will will be set to + nn.batch_matmul + */ std::string batch_op_name_; }; From 17d17b52d6935f404bb3d8af406d6352ac9a0d03 Mon Sep 17 00:00:00 2001 From: Jon Date: Fri, 13 Sep 2019 13:44:33 -0700 Subject: [PATCH 20/21] Update docs --- src/relay/pass/combine_parallel_conv2d.cc | 1 + src/relay/pass/combine_parallel_op.h | 24 ++++++++++++---------- src/relay/pass/combine_parallel_op_batch.h | 6 +++--- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index b461d83391f1..bc9685f815cb 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -190,6 +190,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { } private: + /* \brief index of channel dimension */ size_t channel_pos_; std::tuple TransformWeight(const Group& branches) { diff --git a/src/relay/pass/combine_parallel_op.h b/src/relay/pass/combine_parallel_op.h index 9eeb050a2d44..b9d770ed87a6 100644 --- a/src/relay/pass/combine_parallel_op.h +++ b/src/relay/pass/combine_parallel_op.h @@ -81,30 +81,32 @@ class BranchGroupFinder : private ExprVisitor { /* \brief Finds all groups that can be combined. + \param expr Relay expression that represents function + to look at for groups to be combined \return Vector of groups which can be combined. */ std::vector Find(const Expr& expr); private: - /* name of op to find parallel branches for */ + /* \brief name of op to find parallel branches for */ std::string op_name_; - /* function to return true if op is eligible to be combined, - false otherwise + /* \brief function to return true if op is eligible to be combined, + false otherwise */ FIsSupportedOp fis_supported_op_; - /* function to return true if two parallel ops are eligible - to be combined, false otherwise + /* \brief function to return true if two parallel ops are eligible + to be combined, false otherwise */ FAreCompatibleOps fare_compatible_ops_; - /* ops that are on the first (logically, leftmost) branch - of parallel ops and are eligible to be combined + /* \brief ops that are on the first (logically, leftmost) branch + of parallel ops and are eligible to be combined */ std::unordered_set op_roots_; - /* map of Expr to CallNodes that follow it */ + /* \brief map of Expr to CallNodes that follow it */ std::unordered_map, NodeHash, NodeEqual> children_map_; /* @@ -204,13 +206,13 @@ class ParallelOpCombiner { ExprSubstMap* subst_map) = 0; private: - /* name of op to be combined */ + /* \brief name of op to be combined */ std::string op_name_; - /* minimum number of parallel branches to combine */ + /* \brief minimum number of parallel branches to combine */ uint64_t min_num_branches_; - /* map of Expr to Expr to substitute it with after running pass */ + /* \brief map of Expr to Expr to substitute it with after running pass */ ExprSubstMap subst_map_; /* diff --git a/src/relay/pass/combine_parallel_op_batch.h b/src/relay/pass/combine_parallel_op_batch.h index 56ba1ced23a4..4c5aab595dbc 100644 --- a/src/relay/pass/combine_parallel_op_batch.h +++ b/src/relay/pass/combine_parallel_op_batch.h @@ -136,9 +136,9 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { ExprSubstMap* subst_map) final; private: - /* name of op to replace combined ops with. for example, - for combining parallel dense, this will will be set to - nn.batch_matmul + /* \brief name of op to replace combined ops with. for example, + for combining parallel dense, this will will be set to + nn.batch_matmul */ std::string batch_op_name_; }; From b86854d27de7e1f53a5875d3c500b3d66cfaa71a Mon Sep 17 00:00:00 2001 From: Jon Date: Mon, 16 Sep 2019 11:26:37 -0700 Subject: [PATCH 21/21] Update docs --- src/relay/pass/combine_parallel_op.h | 166 ++++++++++----------- src/relay/pass/combine_parallel_op_batch.h | 114 +++++++------- 2 files changed, 140 insertions(+), 140 deletions(-) diff --git a/src/relay/pass/combine_parallel_op.h b/src/relay/pass/combine_parallel_op.h index b9d770ed87a6..756dba98a707 100644 --- a/src/relay/pass/combine_parallel_op.h +++ b/src/relay/pass/combine_parallel_op.h @@ -50,40 +50,40 @@ using FAreCompatibleOps = std::function; /* - Class to find parallel branches starting with op that are - grouped if they are able to be combined. They are eligible to - be combined if they have the same input data. - Op can be followed by zero or more elemwise or broadcast ops, - which are included in the group. - Intermediate nodes have exactly one successor. It is possible that branches meet at a point, - which should be handled in ParallelOpCombiner. - - data - / \ - op op - | | - elem-wise elem-wise - | | -*/ + * Class to find parallel branches starting with op that are + * grouped if they are able to be combined. They are eligible to + * be combined if they have the same input data. + * Op can be followed by zero or more elemwise or broadcast ops, + * which are included in the group. + * Intermediate nodes have exactly one successor. It is possible that branches meet at a point, + * which should be handled in ParallelOpCombiner. + * + * data + * / \ + * op op + * | | + * elem-wise elem-wise + * | | + */ class BranchGroupFinder : private ExprVisitor { public: /* - \brief Constructor - \param op_name name of op to start each group - \param fis_supported_op function that returns true if op - is supported for combining - \param fare_compatible_ops function that returns true if - two ops are compatible for combining + * \brief Constructor + * \param op_name name of op to start each group + * \param fis_supported_op function that returns true if op + * is supported for combining + * \param fare_compatible_ops function that returns true if + * two ops are compatible for combining */ BranchGroupFinder(const std::string& op_name, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops); /* - \brief Finds all groups that can be combined. - \param expr Relay expression that represents function - to look at for groups to be combined - \return Vector of groups which can be combined. + * \brief Finds all groups that can be combined. + * \param expr Relay expression that represents function + * to look at for groups to be combined + * \return Vector of groups which can be combined. */ std::vector Find(const Expr& expr); @@ -92,17 +92,17 @@ class BranchGroupFinder : private ExprVisitor { std::string op_name_; /* \brief function to return true if op is eligible to be combined, - false otherwise + * false otherwise */ FIsSupportedOp fis_supported_op_; /* \brief function to return true if two parallel ops are eligible - to be combined, false otherwise + * to be combined, false otherwise */ FAreCompatibleOps fare_compatible_ops_; /* \brief ops that are on the first (logically, leftmost) branch - of parallel ops and are eligible to be combined + * of parallel ops and are eligible to be combined */ std::unordered_set op_roots_; @@ -110,82 +110,82 @@ class BranchGroupFinder : private ExprVisitor { std::unordered_map, NodeHash, NodeEqual> children_map_; /* - \brief Creates new branch from op and its children that have - elementwise or broadcast patterns - \return New branch + * \brief Creates new branch from op and its children that have + * elementwise or broadcast patterns + * \return New branch */ Branch CreateBranch(const CallNode* op); /* - \brief Expression visitor function + * \brief Expression visitor function */ void VisitExpr_(const CallNode* n) final; }; /* - Abstract class to find and combine parallel ops and the elementwise ops that follow. -*/ + * Abstract class to find and combine parallel ops and the elementwise ops that follow. + */ class ParallelOpCombiner { public: /* - \brief Constructor. - \param op_name name of op to combine - \param min_num_branches min number of parallel branches beginning with op - to start combining + * \brief Constructor. + * \param op_name name of op to combine + * \param min_num_branches min number of parallel branches beginning with op + * to start combining */ explicit ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches); /* - \brief Combines ops and following elementwise or broadcast ops - \param expr function to modify - \return new function with combined ops + * \brief Combines ops and following elementwise or broadcast ops + * \param expr function to modify + * \return new function with combined ops */ Expr Combine(const Expr& expr); protected: /* - \brief Checks if node is supported to be combined - \param n node in question - \return True if the op represented by n is supported to be the root of a branch - to be combined. False otherwise. + * \brief Checks if node is supported to be combined + * \param n node in question + * \return True if the op represented by n is supported to be the root of a branch + * to be combined. False otherwise. */ virtual bool IsSupportedOp(const CallNode* n) = 0; /* - \brief Checks if two ops can be combined - \param a node a - \param b node b - \return True if a and b can be combined. False otherwise. + * \brief Checks if two ops can be combined + * \param a node a + * \param b node b + * \return True if a and b can be combined. False otherwise. */ virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) = 0; /* - \brief Makes combined op from parallel ops in branches. This usually involves - concatenating or stacking inputs, then creating a new call. - \param branches branches that are to be combined - \return new call with branches combined. + * \brief Makes combined op from parallel ops in branches. This usually involves + * concatenating or stacking inputs, then creating a new call. + * \param branches branches that are to be combined + * \return new call with branches combined. */ virtual Call MakeCombinedOp(const Group& branches) = 0; /* - \brief Checks if argument of op following combined ops are able to be combined - \param a node a - \param b node b - \param index index of argument in question - \return True if argument of a and b and index can be combined + * \brief Checks if argument of op following combined ops are able to be combined + * \param a node a + * \param b node b + * \param index index of argument in question + * \return True if argument of a and b and index can be combined */ virtual bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) = 0; /* - \brief Create combined call from ops that follow the initial combined op at the depth-th level. - This usually involves concatenating or stacking inputs, then creating a new call. - Only called if IsArgCompatbile returns true for each arg. - \param data combined op - \param branches branches of parallel ops to be combined - \param depth depth at which to combine ops - \param parent_index index of arg that corresponds to original input that was shared among - all combined ops - \return new combined call + * \brief Create combined call from ops that follow the initial combined op at the depth-th level. + * This usually involves concatenating or stacking inputs, then creating a new call. + * Only called if IsArgCompatbile returns true for each arg. + * \param data combined op + * \param branches branches of parallel ops to be combined + * \param depth depth at which to combine ops + * \param parent_index index of arg that corresponds to original input that was shared among + * all combined ops + * \return new combined call */ virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, @@ -193,12 +193,12 @@ class ParallelOpCombiner { size_t parent_index) = 0; /* - \brief Updates map of expr to substitute with combined expr. This usually involves - slicing or splitting data. - \param data combined op - \param branches branches of parallel ops to be combined - \param depth depth at which to substitute - \param subst_map map of Expr to replace with Expr to replace it with + * \brief Updates map of expr to substitute with combined expr. This usually involves + * slicing or splitting data. + * \param data combined op + * \param branches branches of parallel ops to be combined + * \param depth depth at which to substitute + * \param subst_map map of Expr to replace with Expr to replace it with */ virtual void UpdateGroupOutput(const Expr& data, const Group& branches, @@ -216,20 +216,20 @@ class ParallelOpCombiner { ExprSubstMap subst_map_; /* - \brief Combine parallel branches and updates subst_map_ with Exprs - to be substituted - \param branches branches to be combined + * \brief Combine parallel branches and updates subst_map_ with Exprs + * to be substituted + * \param branches branches to be combined */ void CombineBranches(const Group& branches); /* - \brief Combine parallel branches and updates subst_map_ with Exprs - to be substituted - \param branches parallel branches to potentially be combined - \param depth depth at which to look at op - \param parent_index index of arg that corresponds to original input that was shared among - all combined ops - \return true if parallel ops at depth can be combined, false otherwise + * \brief Combine parallel branches and updates subst_map_ with Exprs + * to be substituted + * \param branches parallel branches to potentially be combined + * \param depth depth at which to look at op + * \param parent_index index of arg that corresponds to original input that was shared among + * all combined ops + * \return true if parallel ops at depth can be combined, false otherwise */ bool CheckLevel(const Group& branches, size_t depth, size_t parent_index); }; diff --git a/src/relay/pass/combine_parallel_op_batch.h b/src/relay/pass/combine_parallel_op_batch.h index 4c5aab595dbc..84ef8d353985 100644 --- a/src/relay/pass/combine_parallel_op_batch.h +++ b/src/relay/pass/combine_parallel_op_batch.h @@ -42,32 +42,32 @@ namespace tvm { namespace relay { /* - Class to find and combine parallel ops and following element-wise - and broadcast ops into a single batch op. Ops can be combined - if they have the same input data. Batch op is formed by - stacking inputs. Final results are retrieved by splitting output. - For example: - - data - / \ - dense (2,2) dense (2,2) - | | - elemwise/bcast (2,2) elemwise/bcast (2,2) - - Would become: - - data - | - batch_matmul+elemwise/bcast (2,2,2) -*/ + * Class to find and combine parallel ops and following element-wise + * and broadcast ops into a single batch op. Ops can be combined + * if they have the same input data. Batch op is formed by + * stacking inputs. Final results are retrieved by splitting output. + * For example: + * + * data + * / \ + * dense (2,2) dense (2,2) + * | | + * elemwise/bcast (2,2) elemwise/bcast (2,2) + * + * Would become: + * + * data + * | + * batch_matmul+elemwise/bcast (2,2,2) + */ class ParallelOpBatchCombiner : public ParallelOpCombiner { public: /* - \brief Constructor. - \param op_name name of op to combine - \param batch_op_name name of op that combined branches will be joined into - \param min_num_branches min number of parallel branches beginning with op - to start combining + * \brief Constructor. + * \param op_name name of op to combine + * \param batch_op_name name of op that combined branches will be joined into + * \param min_num_branches min number of parallel branches beginning with op + * to start combining */ ParallelOpBatchCombiner(const std::string& op_name, const std::string& batch_op_name, @@ -75,47 +75,47 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { protected: /* - \brief Checks if node is supported to be combined - \param n node in question - \return True by default + * \brief Checks if node is supported to be combined + * \param n node in question + * \return True by default */ virtual bool IsSupportedOp(const CallNode* n); /* - \brief Checks if two ops can be combined - \param a node a - \param b node b - \return True if shapes and dtypes of all args of a and b are the same + * \brief Checks if two ops can be combined + * \param a node a + * \param b node b + * \return True if shapes and dtypes of all args of a and b are the same */ virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b); /* - \brief Makes combined op from parallel ops in branches. This usually involves - concatenating or stacking inputs, then creating a new call. - \param branches branches that are to be combined - \return new call with branches combined as batch op by stacking args + * \brief Makes combined op from parallel ops in branches. This usually involves + * concatenating or stacking inputs, then creating a new call. + * \param branches branches that are to be combined + * \return new call with branches combined as batch op by stacking args */ Call MakeCombinedOp(const Group& branches) final; /* - \brief Checks if argument of op following combined ops are able to be combined - \param a node a - \param b node b - \param index index of argument in question - \return True if shapes and dtypes of args[index] a and b are the same + * \brief Checks if argument of op following combined ops are able to be combined + * \param a node a + * \param b node b + * \param index index of argument in question + * \return True if shapes and dtypes of args[index] a and b are the same */ bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) final; /* - \brief Create combined call from ops that follow the initial combined op at the depth-th level. - This usually involves concatenating or stacking inputs, then creating a new call. - Only called if IsArgCompatbile returns true for each arg. - \param data combined op - \param branches branches of parallel ops to be combined - \param depth depth at which to combine ops - \param parent_index index of arg that corresponds to original input that was shared among - all combined ops - \return new combined call as batch op by stacking args + * \brief Create combined call from ops that follow the initial combined op at the depth-th level. + * This usually involves concatenating or stacking inputs, then creating a new call. + * Only called if IsArgCompatbile returns true for each arg. + * \param data combined op + * \param branches branches of parallel ops to be combined + * \param depth depth at which to combine ops + * \param parent_index index of arg that corresponds to original input that was shared among + * all combined ops + * \return new combined call as batch op by stacking args */ Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, @@ -123,12 +123,12 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { size_t parent_index) final; /* - \brief Updates map of expr to substitute with combined expr. This usually involves - slicing or splitting data. - \param data combined op - \param branches branches of parallel ops to be combined - \param depth depth at which to substitute - \param subst_map map of Expr to replace with Expr to replace it with + * \brief Updates map of expr to substitute with combined expr. This usually involves + * slicing or splitting data. + * \param data combined op + * \param branches branches of parallel ops to be combined + * \param depth depth at which to substitute + * \param subst_map map of Expr to replace with Expr to replace it with */ void UpdateGroupOutput(const Expr& data, const Group& branches, @@ -137,9 +137,9 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { private: /* \brief name of op to replace combined ops with. for example, - for combining parallel dense, this will will be set to - nn.batch_matmul - */ + * for combining parallel dense, this will will be set to + * nn.batch_matmul + */ std::string batch_op_name_; };