Skip to content

Commit

Permalink
[Relay] Serialization round-trip tests (apache#1968)
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky authored and AWS Neo committed Feb 20, 2019
1 parent 84149cb commit 3d568dd
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 7 deletions.
22 changes: 22 additions & 0 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,25 @@ def alpha_equal(lhs, rhs):
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))


def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.
Parameters
----------
lhs: tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_make._graph_equal(lhs, rhs))
2 changes: 1 addition & 1 deletion src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class AlphaEqualHandler:

bool VisitType_(const TypeRelationNode* lhs, const Type& other) final {
if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) {
if (!lhs->func.same_as(rhs->func)) return false;
if (lhs->func->name != rhs->func->name) return false;
if (lhs->num_inputs != rhs->num_inputs) return false;
if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false;
if (lhs->args.size() != rhs->args.size()) return false;
Expand Down
44 changes: 38 additions & 6 deletions tests/python/relay/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
import tvm
from tvm import relay
from tvm.expr import *
from tvm.relay.ir_pass import graph_equal


def check_json_roundtrip(node):
json_str = tvm.save_json(node)
back = tvm.load_json(json_str)
assert graph_equal(back, node)


def test_bad_constructor():
try:
Expand All @@ -21,6 +29,13 @@ def test_span():
assert isinstance(span, relay.base.Span)
str(span)

# span is not a node so we can't use graph_equal
# to test the round trip
back = tvm.load_json(tvm.save_json(span))
assert back.source == span.source
assert back.lineno == span.lineno
assert back.col_offset == span.col_offset

# Types

def test_tensor_type():
Expand All @@ -31,28 +46,31 @@ def test_tensor_type():
assert tt.shape == shape
assert tt.span == None
str(tt)
check_json_roundtrip(tt)


def test_type_param():
tp = relay.TypeVar('name', relay.Kind.Type)
assert tp.kind == relay.Kind.Type
# assert tp.span # TODO allow us to set span
str(tp)
check_json_roundtrip(tp)


def test_func_type():
type_params = tvm.convert([])
type_constraints = tvm.convert([]) # TODO: fill me in
arg_types = tvm.convert([])
ret_type = None
ret_type = relay.TensorType((1, 2, 3), 'float32')
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert tf.type_params == type_params
assert tf.type_constraints == type_constraints
assert tf.arg_types == arg_types
assert tf.ret_type == ret_type
assert tf.span == None
# TODO make sure we can set
# TODO make sure we can set span
str(tf)
check_json_roundtrip(tf)


def test_tuple_type():
Expand All @@ -63,13 +81,15 @@ def test_tuple_type():

tup_ty = relay.TupleType(fields)
assert tup_ty.fields == fields
str(tup_ty)
check_json_roundtrip(tup_ty)


def test_type_relation():
tp = relay.TypeVar('tp', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
args = tvm.convert([tf, tt, tp])
args = tvm.convert([tp, tf, tt])

num_inputs = 2
func = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
Expand All @@ -78,6 +98,8 @@ def test_type_relation():
tr = relay.TypeRelation(func, args, num_inputs, attrs)
assert tr.args == args
assert tr.num_inputs == num_inputs
str(tr)
check_json_roundtrip(tr)


def test_constant():
Expand All @@ -86,6 +108,7 @@ def test_constant():
assert const.data == arr
assert const.span == None
str(const)
check_json_roundtrip(const)


def test_tuple():
Expand All @@ -94,6 +117,7 @@ def test_tuple():
assert tup.fields == fields
assert tup.span == None
str(tup)
check_json_roundtrip(tup)


def test_local_var():
Expand All @@ -103,6 +127,7 @@ def test_local_var():
assert lv.type_annotation is None
# assert lv.span == None todo(@jroesch): what do we do about spans
str(lv)
check_json_roundtrip(lv)

t1 = relay.ty.TensorType((), "float")
lv = relay.Var(name_hint, t1)
Expand All @@ -116,20 +141,22 @@ def test_global_var():
gv.name_hint == name_hint
# assert lv.span == None todo(@jroesch): what do we do about spans
str(gv)
check_json_roundtrip(gv)


def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Var(n) for n in param_names])
ret_type = None
body = None
ret_type = relay.TupleType(tvm.convert([]))
body = relay.Tuple(tvm.convert([]))
type_params = tvm.convert([])
fn = relay.Function(params, ret_type, body, type_params)
fn = relay.Function(params, body, ret_type, type_params)
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
str(fn)
check_json_roundtrip(fn)


def test_call():
Expand All @@ -141,6 +168,7 @@ def test_call():
assert call.args == args
assert call.span == None
str(call)
check_json_roundtrip(call)


def test_let():
Expand All @@ -156,6 +184,7 @@ def test_let():
assert let.body == lv
assert let.span == None
str(let)
check_json_roundtrip(let)


def test_if():
Expand All @@ -168,6 +197,7 @@ def test_if():
assert ife.false_branch == right
assert ife.span == None
str(ife)
check_json_roundtrip(ife)


def test_tuple_get_item():
Expand All @@ -176,6 +206,8 @@ def test_tuple_get_item():
assert get.tuple_value == tup
assert get.index == 1
str(get)
check_json_roundtrip(get)


if __name__ == "__main__":
test_bad_constructor()
Expand Down

0 comments on commit 3d568dd

Please sign in to comment.