From 54ac6902e3ec1bcdc79e8d2bc68e778ec4c76519 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 14 Aug 2020 09:13:42 -0700 Subject: [PATCH] [Parser] Add support for parsing the any dimension. (#6277) * Add case for any dimensions * Fix second test case --- src/parser/parser.cc | 5 +++-- tests/python/relay/test_ir_parser.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 71d4304ca64d..8055d9138235 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1502,6 +1502,8 @@ class Parser { tvm::PrimExpr dim; if (Peek()->token_type == TokenType::kMetaReference) { dim = Downcast(ParseMetaRef()); + } else if (WhenMatch(TokenType::kQuestion)) { + dim = tvm::tir::Any(); } else { dim = Downcast(Match(TokenType::kInteger)->data); } @@ -1585,8 +1587,7 @@ class Parser { return ParseNonPrimitiveType(tok); } } - } - if (WhenMatch(TokenType::kUnderscore)) { + } else if (WhenMatch(TokenType::kUnderscore)) { return IncompleteType(); } else { this->diag_ctx->EmitFatal(Diagnostic::Error(tok->span) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 3fcc7dab5bcd..6d581b6e74db 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -591,6 +591,16 @@ def test_tensor_type(): ) ) + assert_parses_as( + "let %_ : Tensor[(?, 1), float32] = (); ()", + relay.Let( + relay.Var("_", relay.TensorType((tvm.tir.Any(), 1), "float32")), + UNIT, + UNIT + ) + ) + + def test_function_type(): assert_parses_as( @@ -678,6 +688,24 @@ def test_adt_defn(): mod ) +def test_adt_any(): + code = """ + type my_dtype { + my_cons(Tensor[(?, 1), uint16]), + } + """ + mod = parse_module(code) + items = mod.type_definitions.items() + global_type_var, type_data = items[0] + assert global_type_var.name_hint == "my_dtype" + ctors = type_data.constructors + assert len(ctors) == 1 + my_cons = ctors[0] + assert my_cons.name_hint == "my_cons" + ty_shape = my_cons.inputs[0].shape + assert isinstance(ty_shape[0], tvm.tir.Any) + assert ty_shape[1] == 1 + def test_empty_adt_defn(): mod = tvm.IRModule()