Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Label Pattern Partitions #5627

Merged
merged 4 commits into from
May 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ constexpr const char* kSkipOptimization = "SkipOptimization";
constexpr const char* kComposite = "Composite";
/*! \brief Mark the function to be inlined. */
constexpr const char* kInline = "Inline";
/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
} // namespace attr

} // namespace relay
Expand Down
12 changes: 8 additions & 4 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,21 +112,23 @@ def match(self, expr: Expr) -> bool:
"""
return match(self, expr)

def partition(self, expr: Expr) -> bool:
def partition(self, expr: Expr, attrs=None) -> Expr:
"""
Parition the expression into functions defined by this pattern

Parameters
----------
expr : tvm.relay.Expr
The expression to match.
attrs : Optional[Dict[str, Object]]
A dictionary of Attribute name/values to add to the paritioned function

Returns
-------
result : tvm.relay.Expr
The Expression with matched subgraphs replaced by function calls to that subgraph
"""
return partition(self, expr)
return partition(self, expr, attrs)

def dominates(self, parent, path=None):
"""
Expand Down Expand Up @@ -562,7 +564,7 @@ def rewrite(callbacks, expr: Expr) -> Expr:

return ffi.rewrite(tmp, expr)

def partition(pattern: DFPattern, expr: Expr) -> Expr:
def partition(pattern: DFPattern, expr: Expr, attrs=None) -> Expr:
"""
Parition the expression into a series of functions that match the pattern

Expand All @@ -572,10 +574,12 @@ def partition(pattern: DFPattern, expr: Expr) -> Expr:
The pattern to match
expr : tvm.relay.Expr
The expression to split into functions
expr : Optional[Dict[str, Object]]
A dict of attributes to apply to the partitioned function

Returns
-------
result : tvm.relay.Expr
The Expression with matched subgraphs replaced by function calls to that subgraph
"""
return ffi.partition(pattern, expr)
return ffi.partition(pattern, expr, attrs)
73 changes: 66 additions & 7 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ class PatternGrouper : protected MixedModeVisitor {
Expr root_node;
int gid;
Map<DFPattern, Array<Expr>> matched_nodes;
std::string name;
Function function;
Array<Expr> args;
};
Expand All @@ -409,19 +410,25 @@ class PatternGrouper : protected MixedModeVisitor {
}

protected:
using ExprVisitor::VisitExpr_;
void VisitLeaf(const Expr& pre) override {
if (matcher_->Match(pattern_, pre)) {
CreateGroup(pre);
}
}

void VisitExpr_(const FunctionNode* op) override {
if (op->attrs->dict.count(attr::kPartitionedFromPattern) == 0) {
ExprVisitor::VisitExpr_(op);
}
}
/* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
* group overlap analysis */
class MatchExtractor : public ExprMutator {
public:
explicit MatchExtractor(const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual>& inputs)
: inputs_(inputs) {}
const std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>& GetMemo() { return this->memo_; }
const std::string& GetName() { return name_; }

protected:
Expr VisitExpr(const Expr& pre) override {
Expand All @@ -430,6 +437,46 @@ class PatternGrouper : protected MixedModeVisitor {
}
return ExprMutator::VisitExpr(pre);
}
Expr VisitExpr_(const TupleNode* op) override {
auto out = ExprMutator::VisitExpr_(op);
name_ += "Tuple_";
return out;
};
Expr VisitExpr_(const FunctionNode* op) override {
auto out = ExprMutator::VisitExpr_(op);
name_ += "Function";
return out;
};
Expr VisitExpr_(const CallNode* call_node) override {
auto out = ExprMutator::VisitExpr_(call_node);
if (auto operation = call_node->op.as<OpNode>()) {
name_ += operation->name + "_";
} else {
name_ += "Call_";
}
return out;
};
Expr VisitExpr_(const LetNode* op) override {
auto out = ExprMutator::VisitExpr_(op);
name_ += "Let_";
return out;
};
Expr VisitExpr_(const IfNode* op) override {
auto out = ExprMutator::VisitExpr_(op);
name_ += "If_";
return out;
};
Expr VisitExpr_(const TupleGetItemNode* op) override {
auto out = ExprMutator::VisitExpr_(op);
name_ += "TupleGetItem" + std::to_string(op->index) + "_";
return out;
};
Expr VisitExpr_(const MatchNode* op) override {
auto out = ExprMutator::VisitExpr_(op);
name_ += "Match_";
return out;
};
std::string name_;
const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs_;
};

Expand Down Expand Up @@ -487,7 +534,7 @@ class PatternGrouper : protected MixedModeVisitor {
// Verify the pattern still holds
CHECK(DFPatternMatcher(body).Match(pattern_, body));
group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());

group.name = extractor.GetName();
// Check to make sure we aren't overlapping with another group
// The MatchExtractor will create a new graph by replacing nodes that match the inputs of the
// pattern with the input FunctionVar* Variables. The resulting memoization map will only
Expand Down Expand Up @@ -612,10 +659,12 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatt
*/
class PatternPartitioner : protected MixedModeMutator {
public:
Expr Partition(const DFPattern& pattern, const Expr& pre) {
Expr Partition(const DFPattern& pattern, const Expr& pre,
const Map<std::string, ObjectRef>& attrs) {
auto grouper = PatternGrouper();
groups_ = grouper.GroupMatches(pattern, pre);
gid_assignments_ = grouper.GetGIDAssignments();
attrs_ = attrs;
return this->VisitExpr(pre);
}

Expand All @@ -625,7 +674,13 @@ class PatternPartitioner : protected MixedModeMutator {
for (size_t i = 0; i < group.args.size(); ++i) {
args.push_back(memo_[group.args[i]]);
}
return Call(group.function, args);
Function func = WithAttr(group.function, attr::kPartitionedFromPattern, String(group.name));
if (!attrs_.empty()) {
masahi marked this conversation as resolved.
Show resolved Hide resolved
for (auto kv : attrs_) {
func = WithAttr(std::move(func), kv.first, kv.second);
}
}
return Call(func, args);
}

Expr DispatchVisitExpr(const Expr& pre) override {
Expand All @@ -636,15 +691,19 @@ class PatternPartitioner : protected MixedModeMutator {
return post;
}

Map<std::string, ObjectRef> attrs_;
std::vector<PatternGrouper::Group> groups_;
std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
};

Expr PartitionPattern(DFPattern pattern, Expr expr) {
return PatternPartitioner().Partition(pattern, expr);
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
return PatternPartitioner().Partition(pattern, expr, attrs);
}

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition").set_body_typed(PartitionPattern);
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition")
.set_body_typed([](DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
return PartitionPattern(pattern, expr, attrs);
});

} // namespace relay
} // namespace tvm
55 changes: 48 additions & 7 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,41 @@ def test_algebraic_simplify():

assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y)

def test_double_partition():
# Pattern 1
conv2d_p = is_op('nn.conv2d')(wildcard(), wildcard())
bias_add_p = is_op("nn.bias_add")(conv2d_p, wildcard())
relu_p = is_op('nn.relu')(bias_add_p)

# Graph
x = relay.var('input')
w = relay.var('weight')
b = relay.var('bias')
w2 = relay.var('weight')
b2 = relay.var('bias')
conv2d = relay.op.nn.conv2d(x, w)
bias_add = relay.op.nn.bias_add(conv2d, b)
relu = relay.op.nn.relu(bias_add)
conv2d2 = relay.op.nn.conv2d(relu, w2)
bias_add2 = relay.op.nn.bias_add(conv2d2, b2)

partitioned = bias_add2
for pat, label in [(relu_p, "conv_bias_relu"), (bias_add_p, "conv_bias")]:
partitioned = pat.partition(partitioned, {"Composite": label})


inpf = relay.var("input")
weightf = relay.var("weight")
biasf = relay.var("bias")
func0 = relay.Function([inpf, weightf, biasf], relay.op.nn.relu(relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf))).with_attr("Composite", "conv_bias_relu").with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_nn.relu_")
inpf = relay.var("input")
weightf = relay.var("weight")
biasf = relay.var("bias")
func1 = relay.Function([inpf, weightf, biasf], relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf)).with_attr("Composite", "conv_bias").with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_")

expected = func1(func0(x, w, b), w2, b2)
assert tvm.ir.structural_equal(partitioned, expected)

def test_partition_dominator():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
Expand All @@ -721,10 +756,10 @@ def generate_diamond(inp, weight):
out = generate_diamond(inp*inp, weight*weight)
# Check
partitioned = diamond.partition(out)

i = relay.Var("input")
w = relay.Var("weight")
f = relay.Function([i, w], generate_diamond(i, w))
f = relay.Function([i, w], generate_diamond(i, w)).with_attr("PartitionedFromPattern","nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_")
assert tvm.ir.structural_equal(partitioned, f(inp*inp, weight*weight))

def test_quadruple_partition_dominator():
Expand Down Expand Up @@ -783,10 +818,16 @@ def nested_diamond(inp, weight):
)

functions = []
for f in [classic_diamond, deeper_diamond, single_branch, nested_diamond]:
partition_names = [
"nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_",
"nn.conv2d_nn.relu_nn.relu_tanh_nn.leaky_relu_add_",
"nn.conv2d_nn.relu_nn.relu_tanh_add_",
"nn.conv2d_nn.relu_add_tanh_nn.leaky_relu_add_"
]
for i, f in enumerate([classic_diamond, deeper_diamond, single_branch, nested_diamond]):
inpf = relay.var("input")
weightf = relay.var("weight")
functions.append(relay.Function([inpf, weightf], f(inpf, weightf)))
functions.append(relay.Function([inpf, weightf], f(inpf, weightf)).with_attr("PartitionedFromPattern", partition_names[i]))

reference = functions[3](
functions[2](
Expand Down Expand Up @@ -816,7 +857,7 @@ def test_parition_batchnorm():
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
# Put the arguments in toplogological order for the reference
f = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf))
f = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")

partitioned = BatchnormCallback().pattern.partition(BN)
assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta))
Expand All @@ -836,14 +877,14 @@ def test_parition_double_batchnorm():
meanf = relay.var('meanf')
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf))
f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
# The paritioner doesn't replace duplicates, so we use two copies of the function
xf2 = relay.var('xf2')
varf2 = relay.var('varf2')
meanf2 = relay.var('meanf2')
betaf2 = relay.var('betaf2')
gammaf2 = relay.var('gammaf2')
f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2))
f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")

partitioned = BatchnormCallback().pattern.partition(BN2)
reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
Expand Down