Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Add new IR pass CombineParallelDense #3862

Merged
merged 21 commits into from
Sep 24, 2019
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/python/relay/transform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ tvm.relay.transform

.. autofunction:: tvm.relay.transform.CombineParallelConv2D

.. autofunction:: tvm.relay.transform.CombineParallelDense

.. autofunction:: tvm.relay.transform.AlterOpLayout

.. autofunction:: tvm.relay.transform.Legalize
Expand Down
11 changes: 11 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,17 @@ TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);
*/
TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);

/*!
* \brief Combine parallel dense ops into a single batch_matmul if the
* number of branches of this dense operator is not less than
* `min_num_branch`.
*
* \param min_num_branches The minimun number of branches.
*
* \return The pass.
*/
TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3);

/*!
* \brief Backward fold axis scaling into weights of conv/dense operators.
*
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def build_config(opt_level=2,
"CanonicalizeCast": 3,
"EliminateCommonSubexpr": 3,
"CombineParallelConv2D": 4,
"CombineParallelDense": 4
}

fallback_device : int, str, or tvm.TVMContext, optional
Expand Down Expand Up @@ -400,6 +401,23 @@ def CombineParallelConv2D(min_num_branches=3):
return _transform.CombineParallelConv2D(min_num_branches)


def CombineParallelDense(min_num_branches=3):
"""Combine multiple dense operators into one.

Parameters
----------
min_num_branches : int
The minimum number of required parallel branches for performing this
soiferj marked this conversation as resolved.
Show resolved Hide resolved
optimization.

Returns
-------
ret: tvm.relay.Pass
The registered pass that combines parallel dense operators.
"""
return _transform.CombineParallelDense(min_num_branches)


def AlterOpLayout():
"""Alternate the layouts of operators or replace primitive operators with
other expressions.
Expand Down
1 change: 1 addition & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ class RelayBuildModule : public runtime::ModuleNode {
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
Expand Down
230 changes: 44 additions & 186 deletions src/relay/pass/combine_parallel_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
*
* \file combine_parallel_conv2d.cc
* \brief Combine parallel 2d convolutions into a single convolution.
Expand All @@ -43,66 +43,23 @@
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"

#include "./combine_parallel_op.h"

namespace tvm {
namespace relay {

using Branch = std::vector<const CallNode*>;
using Group = std::vector<Branch>;

/*
Find parallel branches starting with conv2d as shown below and then group branches by kernel
shape and attributes of conv2d. Conv2d can be followed by zero or more elemwise or broadcast ops.
Intermediate nodes have exactly one successor. It is possible that branches meet at a point,
which should be handled in ParallelConv2DCombiner.

data
/ \
conv2d conv2d
| |
op op
| |
*/
class BranchGroupFinder : private ExprVisitor {
class ParallelConv2DCombiner : public ParallelOpCombiner {
public:
std::vector<Group> Find(const Expr& expr) {
static const Op& conv2d = Op::Get("nn.conv2d");

this->VisitExpr(expr);

std::vector<Group> groups;
for (const auto& root : conv_roots_) {
const auto& children = children_map_.at(root);
size_t ngroups = groups.size();
for (const CallNode* child : children) {
if (!child->op.same_as(conv2d)) continue;

auto&& branch = CreateBranch(child);
// add the branch to a group, or create a new group
auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) {
CHECK(!group.empty() && !group[0].empty());
return IsCompatibleConv2D(child, group[0][0]);
});
if (it != groups.end()) {
it->push_back(branch);
} else {
groups.emplace_back();
// each group has at least one branch
groups.back().push_back(branch);
}
}
}
return groups;
explicit ParallelConv2DCombiner(uint64_t min_num_branches)
: ParallelOpCombiner("nn.conv2d", min_num_branches) {
}

private:
std::unordered_set<Expr, NodeHash, NodeEqual> conv_roots_;
std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> children_map_;
protected:
bool IsSupportedOp(const CallNode* n) {
return n->attrs.as<Conv2DAttrs>()->groups == 1;
}

// Two 2d convolutions can be combined if they have the same attributes or
// only have different output channels.
bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) {
bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
AttrsEqual eq;
static const Layout kOIHW("OIHW");
const auto* attrs_a = a->attrs.as<Conv2DAttrs>();
Expand All @@ -125,75 +82,7 @@ class BranchGroupFinder : private ExprVisitor {
eq(shape_a[3], shape_b[3]);
}

// Create a branch starting from conv2d.
Branch CreateBranch(const CallNode* conv) {
static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
// each branch has at least one element, the first element is always conv2d
Branch branch{conv};
auto it = children_map_.find(GetRef<Expr>(branch.back()));
while (it != children_map_.end() && it->second.size() == 1) {
const CallNode* call = it->second[0];
auto pattern = fpattern[Downcast<Op>(call->op)];
if (pattern <= kBroadcast) {
branch.push_back(call);
it = children_map_.find(GetRef<Expr>(branch.back()));
} else {
break;
}
}
return branch;
}

void VisitExpr_(const CallNode* n) final {
static const Op& conv2d = Op::Get("nn.conv2d");
ExprVisitor::VisitExpr_(n);
if (n->op.same_as(conv2d) && n->attrs.as<Conv2DAttrs>()->groups == 1) {
conv_roots_.insert(n->args[0]);
children_map_[n->args[0]].push_back(n);
} else {
for (size_t i = 0; i < n->args.size(); i++) {
children_map_[n->args[i]].push_back(n);
}
}
}
};

class ParallelConv2DCombiner {
public:
explicit ParallelConv2DCombiner(uint64_t min_num_branches) : min_num_branches_(min_num_branches) {
}

Expr Combine(const Expr& expr) {
auto groups = BranchGroupFinder().Find(expr);
for (const Group& group : groups) {
if (group.size() < min_num_branches_) {
continue;
}
CombineBranches(group);
}
return ExprSubst(expr, std::move(subst_map_));
}

private:
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;
uint64_t min_num_branches_;

std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
int64_t num_filters = 0; // number of filters of the transformed weight
Array<Expr> weights;
for (const auto& branch : branches) {
auto conv2d = branch[0];
weights.push_back(conv2d->args[1]);
auto channels = GetConv2DSuperChannelsDim(conv2d);
num_filters += channels;
}
auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->kernel_layout.find('O');
CHECK_NE(index, std::string::npos);
return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index),
MakeConstScalar(Int(32), num_filters));
}

Call MakeCombinedConv2D(const Group& branches) {
Call MakeCombinedOp(const Group& branches) {
static const Op& conv2d = Op::Get("nn.conv2d");
Expr data = branches[0][0]->args[0];
Expr new_weight;
Expand All @@ -215,10 +104,15 @@ class ParallelConv2DCombiner {
new_attrs->out_dtype = attrs->out_dtype;
new_attrs->channels = new_channels;

const std::string& layout =
new_attrs->out_layout == "" ? new_attrs->data_layout : new_attrs->out_layout;
channel_pos = layout.find('C');
CHECK_NE(channel_pos, std::string::npos);
soiferj marked this conversation as resolved.
Show resolved Hide resolved

return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {});
}

bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index, size_t channel_pos) {
bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
AttrsEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
Expand All @@ -245,38 +139,10 @@ class ParallelConv2DCombiner {
return true;
}

// Check if ops in depth-th level can be combined
bool CheckLevel(const Group& branches, size_t depth, size_t channel_pos, size_t parent_index) {
const CallNode* call = branches[0][depth];
AttrsEqual attrs_equal;
// check if all branches in current depth can be combined
for (auto it = branches.begin() + 1; it != branches.end(); it++) {
const Branch& branch = *it;
if (!branch[depth]->op.same_as(call->op) ||
!attrs_equal(branch[depth]->attrs, call->attrs) ||
branch[depth]->args.size() != call->args.size()) {
return false;
}

if (branch[depth]->args[parent_index].get() != branch[depth - 1])
return false;

// Check args
for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) continue;

if (!IsArgCompatible(call, branch[depth], i, channel_pos) ||
!attrs_equal(call->attrs, branch[depth]->attrs)) {
return false;
}
}
}
return true;
}

// Combine args and make the combined CallNode
Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t channel_pos,
size_t parent_index) {
Call MakeCombinedCallFromFollowingOps(const Expr& data,
const Group& branches,
size_t depth,
size_t parent_index) {
Array<Expr> new_args;
const CallNode* call = branches[0][depth];
size_t ndim = call->type_as<TensorTypeNode>()->shape.size();
Expand All @@ -286,21 +152,25 @@ class ParallelConv2DCombiner {
new_args.push_back(data);
continue;
}

size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
size_t arg_channel_pos = channel_pos - ndim + arg_ndim;
Array<Expr> tuple;
for (const auto& branch : branches) {
tuple.push_back(branch[depth]->args[i]);
}

auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos);
new_args.push_back(std::move(concat));
}

return CallNode::make(call->op, new_args, call->attrs, {});
}

// Replace output of each branch with slices of the combined output
void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
size_t channel_pos) {
void UpdateGroupOutput(const Expr& data,
const Group& branches,
size_t depth,
ExprSubstMap* subst_map) {
int64_t index = 0;
for (const auto& branch : branches) {
const CallNode* conv2d = branch[0];
Expand All @@ -315,38 +185,26 @@ class ParallelConv2DCombiner {
index += channels;
end.push_back(index);
auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array<Integer>{});
subst_map_[GetRef<Expr>(branch[depth])] = slice;
subst_map->insert({GetRef<Expr>(branch[depth]), slice});
}
}

// Combine branches in a group. Conv2d in different branches in the same group are safe to
// combine. Subsequent ops may or may not be combined. We start from conv2d and try to
// combine ops from all branches in the same depth.
void CombineBranches(const Group& branches) {
Call combined = MakeCombinedConv2D(branches);
auto conv_param = combined->attrs.as<Conv2DAttrs>();
const std::string& layout =
conv_param->out_layout == "" ? conv_param->data_layout : conv_param->out_layout;
size_t channel_pos = layout.find('C');
CHECK_NE(channel_pos, std::string::npos);
auto it = std::min_element(branches.begin(), branches.end(),
[](const Branch& branch_a,
const Branch& branch_b) {
return branch_a.size() < branch_b.size();
});
size_t depth = it->size();
size_t i;
// starting from 1 to skip the conv2d
for (i = 1; i < depth; i++) {
size_t parent_index;
for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) {
if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break;
}
CHECK_NE(parent_index, branches[0][i]->args.size());
if (!CheckLevel(branches, i, channel_pos, parent_index)) break;
combined = MakeCombinedCall(combined, branches, i, channel_pos, parent_index);
private:
size_t channel_pos;

std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
int64_t num_filters = 0; // number of filters of the transformed weight
Array<Expr> weights;
for (const auto& branch : branches) {
auto conv2d = branch[0];
weights.push_back(conv2d->args[1]);
auto channels = GetConv2DSuperChannelsDim(conv2d);
num_filters += channels;
}
UpdateGroupOutput(combined, branches, i - 1, channel_pos);
auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->kernel_layout.find('O');
CHECK_NE(index, std::string::npos);
return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index),
MakeConstScalar(Int(32), num_filters));
}
};

Expand Down
Loading