Skip to content

Commit

Permalink
[Relay] [Parser] fix parser for cast. (apache#3873)
Browse files Browse the repository at this point in the history
* fix

* lint
  • Loading branch information
MarisaKirisame authored and wweic committed Sep 16, 2019
1 parent 31a78a2 commit 616f1e6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def __call__(self, args, attrs, type_args):
try:
return expr.Call(self.operator, args, attrs, type_args)
except Exception:
raise Exception(str(self.operator) + " " + str(attrs))
raise Exception("Operator {} is not registered. It's attributes are {}"
.format(self.operator, attrs))

class FuncOp(OpWrapper):
"""Convert the attrs, call the python function with the attrs passed in as keyword arguments.
Expand Down Expand Up @@ -132,6 +133,7 @@ def __call__(self, args, attrs, type_args):
"nn.dropout": op.nn.dropout_raw,
"zeros": op.zeros,
"split": op.split,
"cast": op.cast
}

TYPE_PREFIXES = [
Expand Down
15 changes: 15 additions & 0 deletions tests/python/relay/test_ir_text_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,19 +169,23 @@ def test_inception_v3():
net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
astext(net)


def test_squeezenet():
for version in ['1.0', '1.1']:
net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version)
astext(net)


def test_vgg():
net, params = tvm.relay.testing.vgg.get_workload(batch_size=1)
astext(net)


def test_densenet():
net, params = tvm.relay.testing.densenet.get_workload(batch_size=1)
astext(net)


def test_call_node_order():
x = relay.var("x")
y = relay.var("y")
Expand All @@ -196,6 +200,7 @@ def test_call_node_order():
"};\n"
"%2(%1)")


def test_let_inlining():
tup = relay.Tuple([relay.const(0), relay.const(0)])
x = relay.var("x")
Expand All @@ -208,10 +213,19 @@ def test_let_inlining():
("let %x = (0, 0);\n"
"%x")


def test_zeros():
x = relay.op.zeros([], "float32")
astext(x)


def test_cast():
data = relay.var('data', dtype='float32')
fp16_cast = relay.cast(data, dtype='float16')
cast_func = relay.Function(relay.analysis.free_vars(fp16_cast), fp16_cast)
astext(cast_func)


if __name__ == "__main__":
do_print[0] = True
test_lstm()
Expand All @@ -233,3 +247,4 @@ def test_zeros():
test_let_if_scope()
test_variable_name()
test_call_node_order()
test_cast()

0 comments on commit 616f1e6

Please sign in to comment.