Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PatternLang]Conditionally Embedding Constants in Partitioned Functions #5693

Merged
merged 2 commits into from
May 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu:
pat.match(out)
The next example is matching a constant node regarding its values. This is useful to check
if a specific parameter in a subgraph has been bind or not.
if a specific parameter in a subgraph has been bound or not.

.. code-block:: python
Expand Down Expand Up @@ -266,10 +266,10 @@ Attribute Pattern

Check that the operator matched by the pattern has an attribute with a particular value.

Input
*****
Variable Pattern
****************

Check that the expression is an input, i.e has no parents and is a variable.
Check that the expression is a relay Variable, and optional provide a name to match to the Variable name.


Alternate
Expand Down
7 changes: 3 additions & 4 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,15 +318,14 @@ class VarPattern(DFPattern):
Parameters
----------
name_hint: str
The name of the variable.
This name only acts as a hint, and is not used
for equality.
The name of the variable. Optional, if not provided,
the pattern will match any VarNode.
type_annotation: tvm.relay.Type, optional
The type annotation on the variable.
"""

def __init__(self, name_hint: str, type_annotation=None):
def __init__(self, name_hint="", type_annotation=None):
self.__init_handle_by_constructor__(
ffi.VarPattern, name_hint, type_annotation)

Expand Down
36 changes: 33 additions & 3 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ class PatternGrouper {
auto matches = node_map[node->ref_];
for (auto match : matches) {
if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
match.as<FunctionNode>() == nullptr) {
match.as<FunctionNode>() == nullptr && !EmbedConst(match, node->ref_)) {
inputs[match] = Var(
"FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
NullValue<Type>());
Expand All @@ -582,8 +582,8 @@ class PatternGrouper {
auto extractor = MatchExtractor(inputs);
auto body = extractor.Mutate(expr);

// Verify the pattern still holds, no longer valid if we're not embedding constants in the
// graph, keep here for future debug CHECK(DFPatternMatcher(body).Match(pattern_, body));
// Verify the pattern still holds
CHECK(DFPatternMatcher(body).Match(pattern_, body));
group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
group.name = extractor.GetName();
// Check to make sure we aren't overlapping with another group
Expand Down Expand Up @@ -613,6 +613,36 @@ class PatternGrouper {
CHECK_EQ(groups_[gid_].gid, gid_);
}

/* \brief EmbedConst implements rules for embedding constants into partitioned functions or
* lifting them into the function arguments.
*
* The rules depend on what pattern the ConstantNode matched.
*
* The basic rules are:
* If the constant matches ExprPattern(relay.const(*)) or a ConstantPattern(), embed the constant
* in the partitioned function. If the constant matched an AltPattern, recursively check the
* matched side of the pattern. For any other matching pattern (i.e, wildcard, VarPattern, etc),
* lift the constant into the arguments of the partitioned function.
*/
bool EmbedConst(const Expr& expr, const DFPattern pattern) {
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
bool embed = false;
if (expr.as<ConstantNode>()) {
if (pattern.as<ConstantPatternNode>() != nullptr) {
embed = true;
} else if (auto expr_pat = pattern.as<ExprPatternNode>()) {
if (expr_pat->expr.as<ConstantNode>()) {
embed = true;
}
} else if (auto alt_pat = pattern.as<AltPatternNode>()) {
if (matcher_->Match(alt_pat->left, expr)) {
embed = EmbedConst(expr, alt_pat->left);
} else {
embed = EmbedConst(expr, alt_pat->right);
}
}
}
return embed;
}
// Internal State
DFPattern pattern_;
std::vector<Group> groups_;
Expand Down
107 changes: 85 additions & 22 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def __init__(self):
self.mean = wildcard()
self.beta = wildcard()
self.gamma = wildcard()
self.eps = wildcard()
self.eps = ConstantPattern()

self.pattern = self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + \
self.beta
Expand Down Expand Up @@ -765,7 +765,7 @@ def algebraic_simplify(expr):

class ElwiseNullCallback(DFPatternCallback):
def callback(self, pre, post, node_map):
return node_map[self.x][0] # pylint: disable=no-member
return node_map[self.x][0] # pylint: disable=no-member

class AddCallback(ElwiseNullCallback):
def __init__(self):
Expand Down Expand Up @@ -1001,15 +1001,15 @@ def test_partition_batchnorm():
meanf = relay.var('meanf')
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
epsf = relay.var('epsf')
# Put the arguments in toplogological order for the reference
f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
f = relay.Function([gammaf, xf, meanf, varf, betaf],
get_BN(xf, varf, meanf, betaf, gammaf,
epsf)).with_attr("PartitionedFromPattern",
"subtract_multiply_add_sqrt_divide_add_")
eps)).with_attr("PartitionedFromPattern",
"subtract_multiply_add_sqrt_divide_add_")

partitioned = BatchnormCallback().pattern.partition(BN)
assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, eps, beta))
reference = f(gamma, x, mean, var, beta)
assert tvm.ir.structural_equal(partitioned, reference)


def test_partition_double_batchnorm():
Expand All @@ -1028,25 +1028,23 @@ def test_partition_double_batchnorm():
meanf = relay.var('meanf')
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
epsf = relay.var('epsf')
f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
f1 = relay.Function([gammaf, xf, meanf, varf, betaf],
get_BN(xf, varf, meanf, betaf, gammaf,
epsf)).with_attr("PartitionedFromPattern",
"subtract_multiply_add_sqrt_divide_add_")
eps)).with_attr("PartitionedFromPattern",
"subtract_multiply_add_sqrt_divide_add_")
# The partitioner doesn't replace duplicates, so we use two copies of the function
xf2 = relay.var('xf2')
varf2 = relay.var('varf2')
meanf2 = relay.var('meanf2')
betaf2 = relay.var('betaf2')
gammaf2 = relay.var('gammaf2')
epsf2 = relay.var('epsf2')
f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2],
f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2],
get_BN(xf2, varf2, meanf2, betaf2, gammaf2,
epsf2)).with_attr("PartitionedFromPattern",
"subtract_multiply_add_sqrt_divide_add_")
eps)).with_attr("PartitionedFromPattern",
"subtract_multiply_add_sqrt_divide_add_")

partitioned = BatchnormCallback().pattern.partition(BN2)
reference = f2(gamma, f1(gamma, x, mean, var, eps, beta), mean, var, eps, beta)
reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
assert tvm.ir.structural_equal(partitioned, reference)


Expand Down Expand Up @@ -1106,6 +1104,13 @@ def check(pre):
assert relu == pattern.partition(relu, check=check)


def conv_bias_relu(x, w, b):
conv2d = relay.op.nn.conv2d(x, w)
bias_add = relay.op.nn.bias_add(conv2d, b)
relu = relay.op.nn.relu(bias_add)
return relu


def test_partition_option():
x = relay.var('x')
w = relay.var('w')
Expand All @@ -1119,12 +1124,6 @@ def test_partition_option():
bias = is_op('nn.bias_add')(conv2d, wildcard())
pattern2 = bias.optional(lambda x: is_op('nn.relu')(x))

def conv_bias_relu(x, w, b):
conv2d = relay.op.nn.conv2d(x, w)
bias_add = relay.op.nn.bias_add(conv2d, b)
relu = relay.op.nn.relu(bias_add)
return relu

relu = conv_bias_relu(x, w, b)

xf = relay.var('x')
Expand Down Expand Up @@ -1153,6 +1152,69 @@ def callback(self, pre, post, node_map):
out = rewrite(TestRewrite(), mod['tensor_concatenate_int64'])
assert tvm.ir.structural_equal(mod['tensor_concatenate_int64'], out)

def test_partition_constant_embedding():
x = relay.var('x')
w = relay.var('w')
wc = relay.const(1)
b = relay.var('b')

xf = relay.var('x')
wf = relay.var('w')
bf = relay.var('b')
embeded_func = relay.Function([xf, bf],
conv_bias_relu(xf, wc,
bf)).with_attr("PartitionedFromPattern",
"nn.conv2d_nn.bias_add_nn.relu_")
xf = relay.var('x')
wf = relay.var('w')
bf = relay.var('b')
lifted_func = relay.Function([xf, wf, bf],
conv_bias_relu(xf, wf,
bf)).with_attr("PartitionedFromPattern",
"nn.conv2d_nn.bias_add_nn.relu_")
relu = conv_bias_relu(x, w, b)
reluc = conv_bias_relu(x, wc, b)

# Check lifting of wildcard matches
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), wildcard()),
wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc))

# Check lifting of input matches
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()),
wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(reluc, pattern.partition(reluc)) #Constants are not Inputs

# Check embedding of constant matches
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
ConstantPattern()),
wildcard()))
assert tvm.ir.structural_equal(relu, pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))

# Check embedding of constant ExprPatterns
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
ExprPattern(wc)),
wildcard()))
assert tvm.ir.structural_equal(relu, pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))

# Check lifting/embedding of Alt matches
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()
| ConstantPattern()),
wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))

# Check lifting/embedding of Alt matches with the other ordering
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(
wildcard(), ConstantPattern() | is_input()), wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))


if __name__ == "__main__":
test_expr_pattern()
test_var_pattern()
Expand Down Expand Up @@ -1209,3 +1271,4 @@ def callback(self, pre, post, node_map):
test_partition_check_types()
test_partition_option()
test_match_match()
test_partition_constant_embedding()