diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 3e559df62ead..f7024fe456a0 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -77,7 +77,8 @@ def __call__(self, args, attrs, type_args): try: return expr.Call(self.operator, args, attrs, type_args) except Exception: - raise Exception(str(self.operator) + " " + str(attrs)) + raise Exception("Operator {} is not registered. It's attributes are {}" + .format(self.operator, attrs)) class FuncOp(OpWrapper): """Convert the attrs, call the python function with the attrs passed in as keyword arguments. @@ -132,6 +133,7 @@ def __call__(self, args, attrs, type_args): "nn.dropout": op.nn.dropout_raw, "zeros": op.zeros, "split": op.split, + "cast": op.cast } TYPE_PREFIXES = [ diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index b55261cb5b58..c6f59d9438af 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -169,19 +169,23 @@ def test_inception_v3(): net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1) astext(net) + def test_squeezenet(): for version in ['1.0', '1.1']: net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version) astext(net) + def test_vgg(): net, params = tvm.relay.testing.vgg.get_workload(batch_size=1) astext(net) + def test_densenet(): net, params = tvm.relay.testing.densenet.get_workload(batch_size=1) astext(net) + def test_call_node_order(): x = relay.var("x") y = relay.var("y") @@ -196,6 +200,7 @@ def test_call_node_order(): "};\n" "%2(%1)") + def test_let_inlining(): tup = relay.Tuple([relay.const(0), relay.const(0)]) x = relay.var("x") @@ -208,10 +213,19 @@ def test_let_inlining(): ("let %x = (0, 0);\n" "%x") + def test_zeros(): x = relay.op.zeros([], "float32") astext(x) + +def test_cast(): + data = relay.var('data', dtype='float32') + fp16_cast = relay.cast(data, dtype='float16') + cast_func = relay.Function(relay.analysis.free_vars(fp16_cast), fp16_cast) + astext(cast_func) + + if __name__ == "__main__": do_print[0] = True test_lstm() @@ -233,3 +247,4 @@ def test_zeros(): test_let_if_scope() test_variable_name() test_call_node_order() + test_cast()