Label Pattern Partitions (apache#5627)
* Label Pattern Partitions with a default label to prevent nested partitions and an optional user-supplied label (see the usage sketch below)

* Add node names in topological order to Partitioned attribute

* respond to review comments

* move partition tag into const in attr namespace
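
A minimal usage sketch of the new behavior, based on the tests added in this commit (variable names here are illustrative): partitioning a conv2d -> bias_add -> relu chain wraps the match in a function that carries the optional user-supplied attribute plus the default "PartitionedFromPattern" label.

from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

# Pattern: nn.conv2d -> nn.bias_add -> nn.relu
conv2d_p = is_op("nn.conv2d")(wildcard(), wildcard())
bias_add_p = is_op("nn.bias_add")(conv2d_p, wildcard())
relu_p = is_op("nn.relu")(bias_add_p)

# Graph to partition
x = relay.var("input")
w = relay.var("weight")
b = relay.var("bias")
relu = relay.op.nn.relu(relay.op.nn.bias_add(relay.op.nn.conv2d(x, w), b))

# The matched subgraph is replaced by a call to a fresh Function carrying the
# user-supplied "Composite" attribute and the default "PartitionedFromPattern"
# label, here "nn.conv2d_nn.bias_add_nn.relu_".
partitioned = relu_p.partition(relu, {"Composite": "conv_bias_relu"})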
Matthew Brookhart authored and Trevor Morris committed Jun 18, 2020
1 parent bcd017b commit cf2a83b
Showing 4 changed files with 124 additions and 18 deletions.
2 changes: 2 additions & 0 deletions include/tvm/relay/function.h
@@ -141,6 +141,8 @@ constexpr const char* kSkipOptimization = "SkipOptimization";
constexpr const char* kComposite = "Composite";
/*! \brief Mark the function to be inlined. */
constexpr const char* kInline = "Inline";
/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
} // namespace attr

} // namespace relay
12 changes: 8 additions & 4 deletions python/tvm/relay/dataflow_pattern/__init__.py
@@ -112,21 +112,23 @@ def match(self, expr: Expr) -> bool:
"""
return match(self, expr)

def partition(self, expr: Expr) -> bool:
def partition(self, expr: Expr, attrs=None) -> Expr:
"""
Partition the expression into functions defined by this pattern
Parameters
----------
expr : tvm.relay.Expr
The expression to match.
attrs : Optional[Dict[str, Object]]
A dictionary of attribute name/values to add to the partitioned function
Returns
-------
result : tvm.relay.Expr
The Expression with matched subgraphs replaced by function calls to that subgraph
"""
return partition(self, expr)
return partition(self, expr, attrs)

def dominates(self, parent, path=None):
"""
@@ -562,7 +564,7 @@ def rewrite(callbacks, expr: Expr) -> Expr:

return ffi.rewrite(tmp, expr)

def partition(pattern: DFPattern, expr: Expr) -> Expr:
def partition(pattern: DFPattern, expr: Expr, attrs=None) -> Expr:
"""
Partition the expression into a series of functions that match the pattern
@@ -572,10 +574,12 @@ def partition(pattern: DFPattern, expr: Expr) -> Expr:
The pattern to match
expr : tvm.relay.Expr
The expression to split into functions
attrs : Optional[Dict[str, Object]]
A dict of attributes to apply to the partitioned function
Returns
-------
result : tvm.relay.Expr
The Expression with matched subgraphs replaced by function calls to that subgraph
"""
return ffi.partition(pattern, expr)
return ffi.partition(pattern, expr, attrs)
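
A quick sketch of the default label when no attrs dict is supplied (names are assumed for illustration): the partitioned function still receives "PartitionedFromPattern", built from the matched node names in topological order.

from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard, partition

x = relay.var("x")
y = relay.var("y")
add_p = is_op("add")(wildcard(), wildcard())

# The single matched add is wrapped in a function whose
# "PartitionedFromPattern" attribute is "add_".
partitioned = partition(add_p, x + y)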
73 changes: 66 additions & 7 deletions src/relay/ir/dataflow_matcher.cc
@@ -386,6 +386,7 @@ class PatternGrouper : protected MixedModeVisitor {
Expr root_node;
int gid;
Map<DFPattern, Array<Expr>> matched_nodes;
std::string name;
Function function;
Array<Expr> args;
};
@@ -409,19 +410,25 @@
}

protected:
using ExprVisitor::VisitExpr_;
void VisitLeaf(const Expr& pre) override {
if (matcher_->Match(pattern_, pre)) {
CreateGroup(pre);
}
}

void VisitExpr_(const FunctionNode* op) override {
if (op->attrs->dict.count(attr::kPartitionedFromPattern) == 0) {
ExprVisitor::VisitExpr_(op);
}
}
/* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
* group overlap analysis */
class MatchExtractor : public ExprMutator {
public:
explicit MatchExtractor(const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual>& inputs)
: inputs_(inputs) {}
const std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>& GetMemo() { return this->memo_; }
const std::string& GetName() { return name_; }

protected:
Expr VisitExpr(const Expr& pre) override {
@@ -430,6 +437,46 @@ class PatternGrouper : protected MixedModeVisitor {
}
return ExprMutator::VisitExpr(pre);
}
Expr VisitExpr_(const TupleNode* op) override {
auto out = ExprMutator::VisitExpr_(op);
name_ += "Tuple_";
return out;
};
Expr VisitExpr_(const FunctionNode* op) override {
auto out = ExprMutator::VisitExpr_(op);
name_ += "Function";
return out;
};
Expr VisitExpr_(const CallNode* call_node) override {
auto out = ExprMutator::VisitExpr_(call_node);
if (auto operation = call_node->op.as<OpNode>()) {
name_ += operation->name + "_";
} else {
name_ += "Call_";
}
return out;
};
Expr VisitExpr_(const LetNode* op) override {
auto out = ExprMutator::VisitExpr_(op);
name_ += "Let_";
return out;
};
Expr VisitExpr_(const IfNode* op) override {
auto out = ExprMutator::VisitExpr_(op);
name_ += "If_";
return out;
};
Expr VisitExpr_(const TupleGetItemNode* op) override {
auto out = ExprMutator::VisitExpr_(op);
name_ += "TupleGetItem" + std::to_string(op->index) + "_";
return out;
};
Expr VisitExpr_(const MatchNode* op) override {
auto out = ExprMutator::VisitExpr_(op);
name_ += "Match_";
return out;
};
std::string name_;
const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs_;
};

@@ -487,7 +534,7 @@ class PatternGrouper : protected MixedModeVisitor {
// Verify the pattern still holds
CHECK(DFPatternMatcher(body).Match(pattern_, body));
group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());

group.name = extractor.GetName();
// Check to make sure we aren't overlapping with another group
// The MatchExtractor will create a new graph by replacing nodes that match the inputs of the
// pattern with the input FunctionVar* Variables. The resulting memoization map will only
@@ -612,10 +659,12 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatt
*/
class PatternPartitioner : protected MixedModeMutator {
public:
Expr Partition(const DFPattern& pattern, const Expr& pre) {
Expr Partition(const DFPattern& pattern, const Expr& pre,
const Map<std::string, ObjectRef>& attrs) {
auto grouper = PatternGrouper();
groups_ = grouper.GroupMatches(pattern, pre);
gid_assignments_ = grouper.GetGIDAssignments();
attrs_ = attrs;
return this->VisitExpr(pre);
}

@@ -625,7 +674,13 @@ class PatternPartitioner : protected MixedModeMutator {
for (size_t i = 0; i < group.args.size(); ++i) {
args.push_back(memo_[group.args[i]]);
}
return Call(group.function, args);
Function func = WithAttr(group.function, attr::kPartitionedFromPattern, String(group.name));
if (!attrs_.empty()) {
for (auto kv : attrs_) {
func = WithAttr(std::move(func), kv.first, kv.second);
}
}
return Call(func, args);
}

Expr DispatchVisitExpr(const Expr& pre) override {
@@ -636,15 +691,19 @@
return post;
}

Map<std::string, ObjectRef> attrs_;
std::vector<PatternGrouper::Group> groups_;
std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
};

Expr PartitionPattern(DFPattern pattern, Expr expr) {
return PatternPartitioner().Partition(pattern, expr);
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
return PatternPartitioner().Partition(pattern, expr, attrs);
}

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition").set_body_typed(PartitionPattern);
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition")
.set_body_typed([](DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
return PartitionPattern(pattern, expr, attrs);
});

} // namespace relay
} // namespace tvm
55 changes: 48 additions & 7 deletions tests/python/relay/test_dataflow_pattern.py
@@ -702,6 +702,41 @@ def test_algebraic_simplify():

assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y)

def test_double_partition():
# Pattern 1
conv2d_p = is_op('nn.conv2d')(wildcard(), wildcard())
bias_add_p = is_op("nn.bias_add")(conv2d_p, wildcard())
relu_p = is_op('nn.relu')(bias_add_p)

# Graph
x = relay.var('input')
w = relay.var('weight')
b = relay.var('bias')
w2 = relay.var('weight')
b2 = relay.var('bias')
conv2d = relay.op.nn.conv2d(x, w)
bias_add = relay.op.nn.bias_add(conv2d, b)
relu = relay.op.nn.relu(bias_add)
conv2d2 = relay.op.nn.conv2d(relu, w2)
bias_add2 = relay.op.nn.bias_add(conv2d2, b2)

partitioned = bias_add2
for pat, label in [(relu_p, "conv_bias_relu"), (bias_add_p, "conv_bias")]:
partitioned = pat.partition(partitioned, {"Composite": label})


inpf = relay.var("input")
weightf = relay.var("weight")
biasf = relay.var("bias")
func0 = relay.Function([inpf, weightf, biasf], relay.op.nn.relu(relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf))).with_attr("Composite", "conv_bias_relu").with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_nn.relu_")
inpf = relay.var("input")
weightf = relay.var("weight")
biasf = relay.var("bias")
func1 = relay.Function([inpf, weightf, biasf], relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf)).with_attr("Composite", "conv_bias").with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_")

expected = func1(func0(x, w, b), w2, b2)
assert tvm.ir.structural_equal(partitioned, expected)

def test_partition_dominator():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
@@ -721,10 +756,10 @@ def generate_diamond(inp, weight):
out = generate_diamond(inp*inp, weight*weight)
# Check
partitioned = diamond.partition(out)

i = relay.Var("input")
w = relay.Var("weight")
f = relay.Function([i, w], generate_diamond(i, w))
f = relay.Function([i, w], generate_diamond(i, w)).with_attr("PartitionedFromPattern","nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_")
assert tvm.ir.structural_equal(partitioned, f(inp*inp, weight*weight))

def test_quadruple_partition_dominator():
@@ -783,10 +818,16 @@ def nested_diamond(inp, weight):
)

functions = []
for f in [classic_diamond, deeper_diamond, single_branch, nested_diamond]:
partition_names = [
"nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_",
"nn.conv2d_nn.relu_nn.relu_tanh_nn.leaky_relu_add_",
"nn.conv2d_nn.relu_nn.relu_tanh_add_",
"nn.conv2d_nn.relu_add_tanh_nn.leaky_relu_add_"
]
for i, f in enumerate([classic_diamond, deeper_diamond, single_branch, nested_diamond]):
inpf = relay.var("input")
weightf = relay.var("weight")
functions.append(relay.Function([inpf, weightf], f(inpf, weightf)))
functions.append(relay.Function([inpf, weightf], f(inpf, weightf)).with_attr("PartitionedFromPattern", partition_names[i]))

reference = functions[3](
functions[2](
@@ -816,7 +857,7 @@ def test_parition_batchnorm():
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
# Put the arguments in topological order for the reference
f = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf))
f = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")

partitioned = BatchnormCallback().pattern.partition(BN)
assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta))
@@ -836,14 +877,14 @@ def test_parition_double_batchnorm():
meanf = relay.var('meanf')
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf))
f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
# The partitioner doesn't replace duplicates, so we use two copies of the function
xf2 = relay.var('xf2')
varf2 = relay.var('varf2')
meanf2 = relay.var('meanf2')
betaf2 = relay.var('betaf2')
gammaf2 = relay.var('gammaf2')
f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2))
f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")

partitioned = BatchnormCallback().pattern.partition(BN2)
reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
