Skip to content

Commit

Permalink
[POC][PatternLang]Remove constants from partitioned functions (apache…
Browse files Browse the repository at this point in the history
…#5663)

* remove constants from partitioned functions

* remove print statements
  • Loading branch information
Matthew Brookhart authored and kevinthesun committed Jun 2, 2020
1 parent 11fe762 commit ae560c8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
6 changes: 3 additions & 3 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ class PatternGrouper {
auto matches = node_map[node->ref_];
for (auto match : matches) {
if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
match.as<FunctionNode>() == nullptr && match.as<ConstantNode>() == nullptr) {
match.as<FunctionNode>() == nullptr) {
inputs[match] = Var(
"FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
NullValue<Type>());
Expand All @@ -577,8 +577,8 @@ class PatternGrouper {
auto extractor = MatchExtractor(inputs);
auto body = extractor.Mutate(expr);

// Verify the pattern still holds
CHECK(DFPatternMatcher(body).Match(pattern_, body));
// Verify the pattern still holds, no longer valid if we're not embedding constants in the
// graph, keep here for future debug CHECK(DFPatternMatcher(body).Match(pattern_, body));
group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
group.name = extractor.GetName();
// Check to make sure we aren't overlapping with another group
Expand Down
25 changes: 15 additions & 10 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,55 +878,60 @@ def nested_diamond(inp, weight):
)
assert tvm.ir.structural_equal(partitioned, reference)

def get_BN(x, var, mean, beta, gamma, eps = 1e-5):
return gamma * (x - mean)/relay.op.sqrt(var + relay.const(eps)) + beta
def get_BN(x, var, mean, beta, gamma, eps):
return gamma * (x - mean)/relay.op.sqrt(var + eps) + beta

def test_partition_batchnorm():
x = relay.var('x')
var = relay.var('var')
mean = relay.var('mean')
beta = relay.var('beta')
gamma = relay.var('gamma')
BN = get_BN(x, var, mean, beta, gamma)
eps = relay.const(1e-5)
BN = get_BN(x, var, mean, beta, gamma, eps)


xf = relay.var('xf')
varf = relay.var('varf')
meanf = relay.var('meanf')
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
epsf = relay.var('epsf')
# Put the arguments in toplogological order for the reference
f = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, epsf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")

partitioned = BatchnormCallback().pattern.partition(BN)
assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta))
assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, eps, beta))

def test_partition_double_batchnorm():
x = relay.var('x')
var = relay.var('var')
mean = relay.var('mean')
beta = relay.var('beta')
gamma = relay.var('gamma')
eps = relay.const(1e-5)

BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
BN = gamma * (x - mean)/relay.op.sqrt(var + eps) + beta
BN2 = gamma * (BN - mean)/relay.op.sqrt(var + eps) + beta

xf = relay.var('xf')
varf = relay.var('varf')
meanf = relay.var('meanf')
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
epsf = relay.var('epsf')
f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, epsf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
# The partitioner doesn't replace duplicates, so we use two copies of the function
xf2 = relay.var('xf2')
varf2 = relay.var('varf2')
meanf2 = relay.var('meanf2')
betaf2 = relay.var('betaf2')
gammaf2 = relay.var('gammaf2')
f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
epsf2 = relay.var('epsf2')
f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2, epsf2)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")

partitioned = BatchnormCallback().pattern.partition(BN2)
reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
reference = f2(gamma, f1(gamma, x, mean, var, eps, beta), mean, var, eps, beta)
assert tvm.ir.structural_equal(partitioned, reference)

def test_partition_check():
Expand Down

0 comments on commit ae560c8

Please sign in to comment.