Label Pattern Partitions with a default label to prevent nested partitions and an optional user-supplied label
Matthew Brookhart committed May 19, 2020
1 parent ec8f642 commit 2c24ece
Showing 3 changed files with 73 additions and 15 deletions.
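A minimal usage sketch of the new behavior (assembled from the tests added in this commit, not part of the diff itself; the "conv_relu" label is an illustrative name): every partitioned subgraph is wrapped in a Function tagged with a default "Partitioned" attribute, which the grouper uses to skip already-partitioned functions, and the optional attrs dict attaches user-supplied labels such as "Composite".

from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

# Pattern: conv2d followed by relu
conv2d_p = is_op('nn.conv2d')(wildcard(), wildcard())
relu_p = is_op('nn.relu')(conv2d_p)

# Graph
x = relay.var('input')
w = relay.var('weight')
out = relay.op.nn.relu(relay.op.nn.conv2d(x, w))

# Partition with a user-supplied label; the matched subgraph becomes a call to
# a Function annotated with both {"Composite": "conv_relu"} and the default
# {"Partitioned": ""}, so a later partition pass will not nest into it.
partitioned = relu_p.partition(out, {"Composite": "conv_relu"})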
12 changes: 8 additions & 4 deletions python/tvm/relay/dataflow_pattern/__init__.py
@@ -112,21 +112,23 @@ def match(self, expr: Expr) -> bool:
"""
return match(self, expr)

def partition(self, expr: Expr) -> bool:
def partition(self, expr: Expr, attrs=None) -> Expr:
"""
Partition the expression into functions defined by this pattern
Parameters
----------
expr : tvm.relay.Expr
The expression to match.
attrs : dict[str->Object]
A dictionary of attribute names/values to add to the partitioned function
Returns
-------
result : tvm.relay.Expr
The Expression with matched subgraphs replaced by function calls to that subgraph
"""
return partition(self, expr)
return partition(self, expr, attrs)

def dominates(self, parent, path=None):
"""
@@ -562,7 +564,7 @@ def rewrite(callbacks, expr: Expr) -> Expr:

return ffi.rewrite(tmp, expr)

def partition(pattern: DFPattern, expr: Expr) -> Expr:
def partition(pattern: DFPattern, expr: Expr, attrs=None) -> Expr:
"""
Partition the expression into a series of functions that match the pattern
@@ -572,10 +574,12 @@ def partition(pattern: DFPattern, expr: Expr) -> Expr:
The pattern to match
expr : tvm.relay.Expr
The expression to split into functions
attrs : dict[str->Object]
A dict of attributes to apply to the partitioned function
Returns
-------
result : tvm.relay.Expr
The Expression with matched subgraphs replaced by function calls to that subgraph
"""
return ffi.partition(pattern, expr)
return ffi.partition(pattern, expr, attrs)
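A quick sketch of the two equivalent Python entry points after this change (illustrative only; the "relu_only" label and variable names are made up): the DFPattern.partition method and the module-level partition function forward to the same FFI call, and the supplied attrs should appear on the resulting partitioned Function alongside the default "Partitioned" marker.

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard, partition

pat = is_op('nn.relu')(wildcard())
expr = relay.op.nn.relu(relay.var('x'))

# Method form and free-function form should produce structurally equal results.
a = pat.partition(expr, {"Composite": "relu_only"})
b = partition(pat, expr, {"Composite": "relu_only"})
assert tvm.ir.structural_equal(a, b)

# The whole expression matched, so the result is a Call whose op is the
# partitioned Function; it should carry the user label and the default marker.
func = a.op
print(func.attrs["Composite"])    # expected: "relu_only"
print(func.attrs["Partitioned"])  # expected: ""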
29 changes: 23 additions & 6 deletions src/relay/ir/dataflow_matcher.cc
@@ -409,12 +409,17 @@ class PatternGrouper : protected MixedModeVisitor {
}

protected:
using ExprVisitor::VisitExpr_;
void VisitLeaf(const Expr& pre) override {
if (matcher_->Match(pattern_, pre)) {
CreateGroup(pre);
}
}

void VisitExpr_(const FunctionNode* op) override {
if (op->attrs->dict.count("Partitioned") == 0) {
ExprVisitor::VisitExpr_(op);
}
}
/* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
* group overlap analysis */
class MatchExtractor : public ExprMutator {
@@ -612,10 +617,12 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatt
*/
class PatternPartitioner : protected MixedModeMutator {
public:
Expr Partition(const DFPattern& pattern, const Expr& pre) {
Expr Partition(const DFPattern& pattern, const Expr& pre,
const Map<std::string, ObjectRef>& attrs) {
auto grouper = PatternGrouper();
groups_ = grouper.GroupMatches(pattern, pre);
gid_assignments_ = grouper.GetGIDAssignments();
attrs_ = attrs;
return this->VisitExpr(pre);
}

@@ -625,7 +632,13 @@ class PatternPartitioner : protected MixedModeMutator {
for (size_t i = 0; i < group.args.size(); ++i) {
args.push_back(memo_[group.args[i]]);
}
return Call(group.function, args);
Function func = WithAttr(group.function, "Partitioned", String(""));
if (!attrs_.empty()) {
for (auto kv : attrs_) {
func = WithAttr(std::move(func), kv.first, kv.second);
}
}
return Call(func, args);
}

Expr DispatchVisitExpr(const Expr& pre) override {
@@ -636,15 +649,19 @@ class PatternPartitioner : protected MixedModeMutator {
return post;
}

Map<std::string, ObjectRef> attrs_;
std::vector<PatternGrouper::Group> groups_;
std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
};

Expr PartitionPattern(DFPattern pattern, Expr expr) {
return PatternPartitioner().Partition(pattern, expr);
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
return PatternPartitioner().Partition(pattern, expr, attrs);
}

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition").set_body_typed(PartitionPattern);
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition")
.set_body_typed([](DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
return PartitionPattern(pattern, expr, attrs);
});

} // namespace relay
} // namespace tvm
47 changes: 42 additions & 5 deletions tests/python/relay/test_dataflow_pattern.py
@@ -702,6 +702,43 @@ def test_algebraic_simplify():

assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y)

def test_double_partition():
# Pattern 1
conv2d_p = is_op('nn.conv2d')(wildcard(), wildcard())
bias_add_p = is_op("nn.bias_add")(conv2d_p, wildcard())
relu_p = is_op('nn.relu')(bias_add_p)

# Graph
x = relay.var('input')
w = relay.var('weight')
b = relay.var('bias')
w2 = relay.var('weight')
b2 = relay.var('bias')
conv2d = relay.op.nn.conv2d(x, w)
bias_add = relay.op.nn.bias_add(conv2d, b)
relu = relay.op.nn.relu(bias_add)
conv2d2 = relay.op.nn.conv2d(relu, w2)
bias_add2 = relay.op.nn.bias_add(conv2d2, b2)

partitioned = bias_add2
for pat, label in [(relu_p, "conv_bias_relu"), (bias_add_p, "conv_bias")]:
partitioned = pat.partition(partitioned, {"Composite": label})


inpf = relay.var("input")
weightf = relay.var("weight")
biasf = relay.var("bias")
func0 = relay.Function([inpf, weightf, biasf], relay.op.nn.relu(relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf))).with_attr("Composite", "conv_bias_relu").with_attr("Partitioned","")
inpf = relay.var("input")
weightf = relay.var("weight")
biasf = relay.var("bias")
func1 = relay.Function([inpf, weightf, biasf], relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf)).with_attr("Composite", "conv_bias").with_attr("Partitioned","")

expected = func1(func0(x, w, b), w2, b2)
print(partitioned)
print(expected)
assert tvm.ir.structural_equal(partitioned, expected)

def test_partition_dominator():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
@@ -724,7 +761,7 @@ def generate_diamond(inp, weight):

i = relay.Var("input")
w = relay.Var("weight")
f = relay.Function([i, w], generate_diamond(i, w))
f = relay.Function([i, w], generate_diamond(i, w)).with_attr("Partitioned","")
assert tvm.ir.structural_equal(partitioned, f(inp*inp, weight*weight))

def test_quadruple_partition_dominator():
@@ -786,7 +823,7 @@ def nested_diamond(inp, weight):
for f in [classic_diamond, deeper_diamond, single_branch, nested_diamond]:
inpf = relay.var("input")
weightf = relay.var("weight")
functions.append(relay.Function([inpf, weightf], f(inpf, weightf)))
functions.append(relay.Function([inpf, weightf], f(inpf, weightf)).with_attr("Partitioned",""))

reference = functions[3](
functions[2](
@@ -816,7 +853,7 @@ def test_parition_batchnorm():
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
# Put the arguments in topological order for the reference
f = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf))
f = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("Partitioned","")

partitioned = BatchnormCallback().pattern.partition(BN)
assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta))
@@ -836,14 +873,14 @@ def test_parition_double_batchnorm():
meanf = relay.var('meanf')
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf))
f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("Partitioned","")
# The partitioner doesn't replace duplicates, so we use two copies of the function
xf2 = relay.var('xf2')
varf2 = relay.var('varf2')
meanf2 = relay.var('meanf2')
betaf2 = relay.var('betaf2')
gammaf2 = relay.var('gammaf2')
f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2))
f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2)).with_attr("Partitioned","")

partitioned = BatchnormCallback().pattern.partition(BN2)
reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
