Skip to content

Commit

Permalink
Revert 's/graph_equal/is_unifiable' change
Browse files Browse the repository at this point in the history
  • Loading branch information
weberlo committed Sep 5, 2019
1 parent 8c71f17 commit 8f63f9d
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
8 changes: 4 additions & 4 deletions python/tvm/relay/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def assert_alpha_equal(lhs, rhs):
_make._assert_alpha_equal(lhs, rhs)


def is_unifiable(lhs, rhs):
def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
Expand All @@ -273,10 +273,10 @@ def is_unifiable(lhs, rhs):
result : bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_make._is_unifiable(lhs, rhs))
return bool(_make._graph_equal(lhs, rhs))


def assert_is_unifiable(lhs, rhs):
def assert_graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
Expand All @@ -290,7 +290,7 @@ def assert_is_unifiable(lhs, rhs):
rhs : tvm.relay.Expr
One of the input Expression.
"""
_make._assert_is_unifiable(lhs, rhs)
_make._assert_graph_equal(lhs, rhs)


def collect_device_info(expr):
Expand Down
8 changes: 4 additions & 4 deletions src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -596,15 +596,15 @@ TVM_REGISTER_API("relay._make._assert_alpha_equal")
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
});

TVM_REGISTER_API("relay._make._is_unifiable")
TVM_REGISTER_API("relay._make._graph_equal")
.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
return AlphaEqualHandler(true, false).Equal(a, b);
});

TVM_REGISTER_API("relay._make._assert_is_unifiable")
TVM_REGISTER_API("relay._make._assert_graph_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool is_unifiable = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(is_unifiable) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
});

} // namespace relay
Expand Down
14 changes: 7 additions & 7 deletions tests/python/relay/test_ir_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
import tvm
from tvm import relay
from tvm.relay.analysis import is_unifiable, assert_is_unifiable
from tvm.relay.analysis import graph_equal, assert_graph_equal
from nose.tools import nottest, raises
from numpy import isclose
from typing import Union
Expand Down Expand Up @@ -69,7 +69,7 @@

def roundtrip(expr):
x = relay.fromtext(str(expr))
assert_is_unifiable(x, expr)
assert_graph_equal(x, expr)


def parse_text(code):
Expand All @@ -81,7 +81,7 @@ def parse_text(code):
def parses_as(code, expr):
# type: (str, relay.Expr) -> bool
parsed = parse_text(code)
result = is_unifiable(parsed, expr)
result = graph_equal(parsed, expr)
return result

def get_scalar(x):
Expand Down Expand Up @@ -177,13 +177,13 @@ def test_bin_op():


def test_parens():
assert is_unifiable(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1"))
assert not is_unifiable(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)"))
assert graph_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1"))
assert not graph_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)"))


def test_op_assoc():
assert is_unifiable(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
assert is_unifiable(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))
assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))


@nottest
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_ir_text_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import tvm.relay.testing
import numpy as np
from tvm.relay import Expr
from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_is_unifiable, free_vars
from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equal, free_vars

do_print = [False]

Expand All @@ -31,7 +31,7 @@ def astext(p, unify_free_vars=False):
return txt
x = relay.fromtext(txt)
if unify_free_vars:
assert_is_unifiable(x, p)
assert_graph_equal(x, p)
else:
assert_alpha_equal(x, p)
return txt
Expand Down

0 comments on commit 8f63f9d

Please sign in to comment.