Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][PASS] CombineParallelConv2D #2089

Merged
merged 15 commits into from
Nov 22, 2018
5 changes: 5 additions & 0 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# List of optimization pass and level when switch on
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"CombineParallelConv2D": 1,
"OpFusion": 1,
"FoldConstant": 2,
"FoldScaleAxis": 3,
Expand Down Expand Up @@ -144,6 +145,10 @@ def optimize(func, params=None):
func = ir_pass.infer_type(func)
func = ir_pass.simplify_inference(func)

if cfg.pass_enabled("CombineParallelConv2D"):
func = ir_pass.infer_type(func)
func = ir_pass.combine_parallel_conv2d(func)

if cfg.pass_enabled("FoldScaleAxis"):
func = ir_pass.infer_type(func)
func = ir_pass.backward_fold_scale_axis(func)
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,19 @@ def fuse_ops(expr, opt_level=1):
Transformed expression, containing fused result.
"""
return _ir_pass.FuseOps(expr, opt_level)


def combine_parallel_conv2d(expr):
"""Fold multiple conv2d into one.

Parameters
----------
expr : tvm.relay.Expr
The input expression.

Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression
"""
return _ir_pass.CombineParallelConv2D(expr)
328 changes: 328 additions & 0 deletions src/relay/pass/combine_parallel_conv2d.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,328 @@
/*!
* Copyright (c) 2018 by Contributors
*
* \file combine_parallel_conv2d.cc
* \brief Combine parallel 2d convolutions into a single convolution.
*
* This pass replaces convolutions that share the same input node and the same
* arguments (except that the number of output channels can be different) with a
* single convolution. The weight of the new 2d convolution is the concatenation
* of the original weights. Elemwise and broadcast ops following conv2d are also
* combined if possible.
*
* This prevents launching multiple kernels in networks with multiple
* convolution branches, such as Inception block.
*/

#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"


namespace tvm {
namespace relay {

using Branch = std::vector<const CallNode*>;
using Group = std::vector<Branch>;

vinx13 marked this conversation as resolved.
Show resolved Hide resolved
/*
Find parallel branches starting with conv2d as shown below and then group branches by kernel
shape and attributes of conv2d. Conv2d can be followed by zero or more elemwise or broadcast ops.
Intermediate nodes have exactly one successor. It is possible that branches meet at a point,
which should be handled in ParallelConv2DCombiner.

data
/ \
conv2d conv2d
| |
op op
| |
*/
class BranchGroupFinder : private ExprVisitor {
public:
std::vector<Group> Find(const Expr& expr) {
this->VisitExpr(expr);

std::vector<Group> groups;
for (const auto& root : conv_roots_) {
const auto& convs = children_map_.at(root);
for (const CallNode* conv : convs) {
auto&& branch = CreateBranch(conv);
// add the branch to a group, or create a new group
auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) {
CHECK(!group.empty() && !group[0].empty());
return IsCompatibleConv2D(conv, group[0][0]);
});
if (it != groups.end()) {
it->push_back(branch);
} else {
groups.emplace_back();
// each group has at least one branch
groups.back().push_back(branch);
}
}
}
return groups;
}

private:
std::unordered_set<Expr, NodeHash, NodeEqual> conv_roots_;
std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> children_map_;

// Two 2d convolutions can be combined if they have the same attributes or
// only have different output channels.
bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) {
AttrsEqual eq;
static const Layout kOIHW("OIHW");
const auto* attrs_a = a->attrs.as<Conv2DAttrs>();
const auto* attrs_b = b->attrs.as<Conv2DAttrs>();
CHECK(attrs_a);
CHECK(attrs_b);
const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW);
const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW);

return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
eq(attrs_a->data_layout, attrs_b->data_layout) &&
eq(attrs_a->weight_layout, attrs_b->weight_layout) &&
eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) &&
eq(shape_a[3], shape_b[3]);
}

// Create a branch starting from conv2d.
Branch CreateBranch(const CallNode* conv) {
static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
// each branch has at least one element, the first element is always conv2d
Branch branch{conv};
auto it = children_map_.find(GetRef<Expr>(branch.back()));
while (it != children_map_.end() && it->second.size() == 1) {
const CallNode* call = it->second[0];
auto pattern = fpattern[Downcast<Op>(call->op)];
if (pattern <= kBroadcast) {
branch.push_back(it->second[0]);
it = children_map_.find(GetRef<Expr>(branch.back()));
} else {
break;
}
}
return branch;
}

void VisitExpr_(const CallNode* n) final {
static const Op& conv2d = Op::Get("nn.conv2d");
ExprVisitor::VisitExpr_(n);
if (n->op.same_as(conv2d) && n->attrs.as<Conv2DAttrs>()->groups == 1) {
conv_roots_.insert(n->args[0]);
children_map_[n->args[0]].push_back(n);
} else {
for (size_t i = 0; i < n->args.size(); i++) {
children_map_[n->args[i]].push_back(n);
}
}
}
};

class ParallelConv2DCombiner {
public:
Expr Combine(const Expr& expr) {
auto groups = BranchGroupFinder().Find(expr);
for (const Group& group : groups) {
if (group.size() < 2) continue;
CombineBranches(group);
}
return ExprSubst(expr, std::move(subst_map_));
}

private:
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;

std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
int64_t num_filters = 0; // number of filters of the transformed weight
Array<Expr> weights;
for (const auto& branch : branches) {
auto conv2d = branch[0];
weights.push_back(conv2d->args[1]);
auto channels = GetConv2DSuperChannelsDim(conv2d);
num_filters += channels;
}
auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->weight_layout.find('O');
CHECK_NE(index, std::string::npos);
return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index),
MakeConstScalar(Int(32), num_filters));
}

Call MakeCombinedConv2D(const Group& branches) {
static const Op& conv2d = Op::Get("nn.conv2d");
Expr data = branches[0][0]->args[0];
Expr new_weight;
IndexExpr new_channels;
std::tie(new_weight, new_channels) = TransformWeight(branches);

const CallNode* group_root = branches[0][0];
const auto* attrs = group_root->attrs.as<Conv2DAttrs>();
CHECK(attrs);
const auto new_attrs = make_node<Conv2DAttrs>();
new_attrs->strides = attrs->strides;
new_attrs->padding = attrs->padding;
new_attrs->dilation = attrs->dilation;
new_attrs->groups = attrs->groups;
new_attrs->kernel_size = attrs->kernel_size;
new_attrs->data_layout = attrs->data_layout;
new_attrs->weight_layout = attrs->weight_layout;
new_attrs->out_layout = attrs->out_layout;
new_attrs->out_dtype = attrs->out_dtype;
new_attrs->channels = new_channels;

return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {});
}

bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index, size_t channel_pos) {
AttrsEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
auto toutput_a = a->type_as<TensorTypeNode>();
auto toutput_b = b->type_as<TensorTypeNode>();

if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size())
return false;

// Position of the 'C' dimension in the argument
size_t arg_channel_pos = channel_pos - toutput_a->shape.size() + ta->shape.size();

// Channel super-dimension shoule be present and not broadcasted
if ((arg_channel_pos > channel_pos) || // size_t overflow
!eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos]) ||
!eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos]))
return false;

for (size_t i = 0; i < ta->shape.size(); i++) {
if (i == arg_channel_pos) continue;
if (!eq(ta->shape[i], tb->shape[i]))
return false;
}
return true;
}

// Check if ops in depth-th level can be combined
bool CheckLevel(const Group& branches, size_t depth, size_t channel_pos, size_t parent_index) {
const CallNode* call = branches[0][depth];
AttrsEqual attrs_equal;
// check if all branches in current depth can be combined
for (auto it = branches.begin() + 1; it != branches.end(); it++) {
const Branch& branch = *it;
if (!branch[depth]->op.same_as(call->op) ||
!attrs_equal(branch[depth]->attrs, call->attrs) ||
branch[depth]->args.size() != call->args.size()) {
return false;
}

if (branch[depth]->args[parent_index].get() != branch[depth - 1])
return false;

// Check args
for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) continue;

if (!IsArgCompatible(call, branch[depth], i, channel_pos) ||
!attrs_equal(call->attrs, branch[depth]->attrs)) {
return false;
}
}
}
return true;
}

// Combine args and make the combined CallNode
Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t channel_pos,
size_t parent_index) {
Array<Expr> new_args;
const CallNode* call = branches[0][depth];
size_t ndim = call->type_as<TensorTypeNode>()->shape.size();

for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) {
new_args.push_back(data);
continue;
}
size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
size_t arg_channel_pos = channel_pos - ndim + arg_ndim;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you give an example of values of ndim and arg_ndim? When can channel_pos and arg_channel_pos be different?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ndim is the dimension of output of con2
d, typically 4 if the layout is NCHW; arg_ndim is the dimension of argument, which doesn’t have to have the batch dimension, for example the per-channel bias have the shape [channels, 1, 1], and arg_ndim is 3

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, then for NCHW layout arg_channel_pos is 0, and for NHWC, the bias shape should be [1, 1, channels], so arg_channel_pos is 2 ( 3 - 4 + 3). The math seems to check out.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes in this case. channel_pos - ndim + arg_ndim can be seen as arg_ndim - (ndim - channel_pos), which means start marching each dimension from the right (as per the broadcast semantic)

Array<Expr> tuple;
for (const auto& branch : branches) {
tuple.push_back(branch[depth]->args[i]);
}
auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos);
new_args.push_back(std::move(concat));
}
return CallNode::make(call->op, new_args, call->attrs, {});
}

// Replace output of each branch with slices of the combined output
void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
size_t channel_pos) {
int64_t index = 0;
for (const auto& branch : branches) {
const CallNode* conv2d = branch[0];
int64_t channels = GetConv2DSuperChannelsDim(conv2d);
Array<Integer> begin;
Array<Integer> end;
for (size_t i = 0; i < channel_pos; i++) {
begin.push_back(0);
end.push_back(NullValue<Integer>());
}
begin.push_back(index);
index += channels;
end.push_back(index);
auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array<Integer>{});
subst_map_[GetRef<Expr>(branch[depth])] = slice;
}
}

// Combine branches in a group. Conv2d in different branches in the same group are safe to
// combine. Subsequent ops may or may not be combined. We start from conv2d and try to
// combine ops from all branches in the same depth.
void CombineBranches(const Group& branches) {
Call combined = MakeCombinedConv2D(branches);
auto conv_param = combined->attrs.as<Conv2DAttrs>();
const std::string& layout =
conv_param->out_layout == "" ? conv_param->data_layout : conv_param->out_layout;
size_t channel_pos = layout.find('C');
CHECK_NE(channel_pos, std::string::npos);
auto it = std::min_element(branches.begin(), branches.end(),
[](const Branch& branch_a,
const Branch& branch_b) {
return branch_a.size() < branch_b.size();
});
size_t depth = it->size();
size_t i;
// starting from 1 to skip the conv2d
for (i = 1; i < depth; i++) {
size_t parent_index;
for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) {
if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break;
}
CHECK_NE(parent_index, branches[0][i]->args.size());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When branches[0][i] is a node where multiple branches meet, parent_index can be > 0. This is the only case when parent_index can be > 0. Correct?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another case is that the parent is on rhs of binary op

if (!CheckLevel(branches, i, channel_pos, parent_index)) break;
combined = MakeCombinedCall(combined, branches, i, channel_pos, parent_index);
}
UpdateGroupOutput(combined, branches, i - 1, channel_pos);
}
};

Expr CombineParallelConv2D(const Expr& expr) { return ParallelConv2DCombiner().Combine(expr); }

TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CombineParallelConv2D(args[0]);
});

} // namespace relay
} // namespace tvm
Loading