diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h
index ca0b7163aa66..1e897b5444c9 100644
--- a/include/tvm/parser/source_map.h
+++ b/include/tvm/parser/source_map.h
@@ -41,18 +41,21 @@ namespace parser {
  * source of a TVM program.
  */
 struct Source {
+  /*! \brief The source name. */
+  SourceName source_name;
+
   /*! \brief The raw source. */
   std::string source;
   /*! \brief A mapping of line breaks into the raw source. */
   std::vector<std::pair<int, int>> line_map;

   /*! \brief An empty source. */
-  Source() : source(), line_map() {}
+  Source() : source_name(), source(), line_map() {}

   /*! \brief Construct a source from a string. */
-  TVM_DLL explicit Source(const std::string& source);
+  TVM_DLL explicit Source(const SourceName& src_name, const std::string& source);

-  TVM_DLL Source(const Source& source) : source(source.source), line_map(source.line_map) {}
+  TVM_DLL Source(const Source& source) : source_name(source.source_name), source(source.source), line_map(source.line_map) {}

   /*! \brief Generate an error message at a specific line and column with the
    * annotated message.
diff --git a/src/parser/parser.cc b/src/parser/parser.cc
index dcc98169c191..a68d69cd5a0d 100644
--- a/src/parser/parser.cc
+++ b/src/parser/parser.cc
@@ -517,9 +517,29 @@ class Parser {
     return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser);
   }

-  Object ParseMetaRef() {
-    Consume(TokenType::kMetaReference);
-    LOG(FATAL) << "implement me";
+  ObjectRef ParseMetaRef() {
+    auto meta_ref = Match(TokenType::kMetaReference);
+    Call ref = Downcast<Call>(meta_ref->data);
+    auto attrs = ref->attrs.as<MetaRefAttrs>();
+    auto type_key = attrs->node_type_key;
+    auto index = attrs->node_index;
+    auto it = this->meta_table.find(type_key);
+    if (it != this->meta_table.end()) {
+      auto nodes = (*it).second;
+      if (index < nodes.size()) {
+        return nodes[index];
+      } else {
+        this->diag_ctx->Emit(
+            Diagnostic::Error(meta_ref->span)
+            << "the node index `" << index << "` is out of bounds for `" << type_key << "`");
+        return ObjectRef();
+      }
+    } else {
+      this->diag_ctx->Emit(
+          Diagnostic::Error(meta_ref->span)
+          << "no entry in the meta table for `" << type_key << "`");
+      return ObjectRef();
+    }
   }

   /*! \brief Parses a sequence beginning with a start token, seperated by a seperator token, and
    * ending with a stop token.
@@ -607,8 +627,7 @@ class Parser {
     auto mod = IRModule({}, types);

     for (auto func : defs.funcs) {
-      auto function = ExpandMetaRefs(metadata, func.function);
-      mod->Add(func.global, function);
+      mod->Add(func.global, func.function);
     }

     return mod;
@@ -801,8 +820,14 @@ class Parser {
       case TokenType::kFreeVar: {
         Consume(TokenType::kFreeVar);
         auto var_token = Match(TokenType::kLocal);
-        Match(TokenType::kColon);
-        auto type = ParseType();
+
+        Type type;
+        if (WhenMatch(TokenType::kColon)) {
+          type = ParseType();
+        } else {
+          type = IncompleteType();
+        }
+
         BindFreeVar(var_token.ToString(), type);
         break;
       }
@@ -950,7 +975,7 @@ class Parser {

   /*! Parse a function definition without a leading keyword or identifier.
    *
-   * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN, UN) -> Ret { body }.
+   * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }.
    */
   Function ParseFunctionDef() {
     DLOG(INFO) << "Parser::ParseFunctionDef";
@@ -968,6 +993,8 @@ class Parser {
       });
     }

+    Map<String, ObjectRef> raw_attrs;
+
     auto params = ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() {
       auto token = Match(TokenType::kLocal);
@@ -977,6 +1004,16 @@ class Parser {
         type = ParseType();
       }
       return BindVar(string, type);
+    }, [&] {
+      auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
+      auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual;
+
+      if (is_ident && next_is_equal) {
+        raw_attrs = ParseAttrs();
+        return true;
+      }
+
+      return false;
     });

     Type ret_type;
@@ -990,7 +1027,12 @@ class Parser {
     PopTypeScopes(1);
     PopScopes(1);

-    return relay::Function(params, body, ret_type, generics);
+    // TODO(@jroesch): attributes should never be null, they should always be empty.
+    if (raw_attrs.size()) {
+      return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));
+    } else {
+      return relay::Function(params, body, ret_type, generics);
+    }
   }

   /*! \brief Parse an if-expression. */
@@ -1170,6 +1212,22 @@ class Parser {
       return ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare,
                            [&]() { return ParseAttributeValue(); });
      }
+      case TokenType::kOpenParen: {
+        // TODO(@jroesch): need to figure out bracket vs. sequence
+        // return ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
+        //                      [&]() { return ParseAttributeValue(); });
+        return Bracket(TokenType::kOpenParen, TokenType::kCloseParen,
+                       [&]() { return ParseAttributeValue(); });
+      }
+      // TODO(@jroesch): not sure about this being the right way to handle nulls.
+      case TokenType::kIdentifier: {
+        if (auto text = next->data.as()) {
+          std::string id = GetRef(text);
+          if (id == "nullptr") {
+            Match(TokenType::kIdentifier);
+            return ObjectRef();
+          }
+        }
+      }
       default:
         return ParseAtomicExpr();
     }
@@ -1278,6 +1336,7 @@ class Parser {
   }

   Expr GetOp(const std::string& op_name, const Token& tok) {
+    DLOG(INFO) << "op_name=" << op_name << " token=" << tok;
     try {
       return Op::Get(op_name);
     } catch (dmlc::Error e) {
@@ -1335,6 +1394,7 @@ class Parser {
           return Expr(ctor.value());
         } else {
           auto idents = ParseHierName();
+          CHECK_NE(idents.size(), 0);
           std::stringstream op_name;
           int i = 0;
           int periods = idents.size() - 1;
@@ -1354,8 +1414,6 @@ class Parser {
       }
       case TokenType::kMetaReference: {
         return Downcast<Expr>(ParseMetaRef());
-        Consume(TokenType::kMetaReference);
-        return Downcast<Expr>(next->data);
       }
       case TokenType::kFn: {
         Consume(TokenType::kFn);
@@ -1408,7 +1466,8 @@ class Parser {
   Array ParseHierName() {
     Array idents;
     while (Peek()->token_type == TokenType::kIdentifier) {
-      idents.push_back(Peek().ToString());
+      auto name = Peek().ToString();
+      idents.push_back(name);
       Consume(TokenType::kIdentifier);

       if (Peek()->token_type == TokenType::kPeriod) {
@@ -1426,8 +1485,14 @@ class Parser {
   Array<tvm::PrimExpr> ParseShape() {
     auto dims = ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() {
-      auto tok = Match(TokenType::kInteger);
-      return Downcast<tvm::PrimExpr>(tok->data);
+      tvm::PrimExpr dim;
+      if (Peek()->token_type == TokenType::kMetaReference) {
+        dim = Downcast<tvm::PrimExpr>(ParseMetaRef());
+      } else {
+        dim = Downcast<tvm::PrimExpr>(Match(TokenType::kInteger)->data);
+      }
+
+      return dim;
     });
     return dims;
   }
@@ -1565,10 +1630,12 @@ class Parser {
 IRModule ParseModule(std::string file_name, std::string file_content) {
   DLOG(INFO) << "ParseModule";
   SourceName src_name = SourceName::Get(file_name);
-  Source src(file_content);
+  Source src(src_name, file_content);
   DiagnosticContext ctx(src);
-  auto tokens = Tokenize(&ctx, src_name, file_content);
-  Parser parser(&ctx, src_name, tokens, DefaultOpTable(), Source(file_content));
+  auto tokens_and_table = Tokenize(&ctx, src_name, file_content);
+  auto tokens = tokens_and_table.first;
+  auto meta_data_table = tokens_and_table.second;
+  Parser parser(&ctx, src_name, tokens, DefaultOpTable(), src, meta_data_table.ToMetadata());
   auto mod = parser.ParseModule();
   // NB(@jroesch): it is very important that we render any errors before we procede
   // if there were any errors which allow the parser to procede we must render them
@@ -1580,10 +1647,12 @@ IRModule ParseModule(std::string file_name, std::string file_content) {
 Expr ParseExpr(std::string file_name, std::string file_content) {
   DLOG(INFO) << "ParseExpr";
   SourceName src_name = SourceName::Get(file_name);
-  Source src(file_content);
+  Source src(src_name, file_content);
   DiagnosticContext ctx(src);
-  auto tokens = Tokenize(&ctx, src_name, file_content);
-  Parser parser(&ctx, src_name, tokens, DefaultOpTable(), Source(file_content));
+  auto tokens_and_table = Tokenize(&ctx, src_name, file_content);
+  auto tokens = tokens_and_table.first;
+  auto meta_data_table = tokens_and_table.second;
+  Parser parser(&ctx, src_name, tokens, DefaultOpTable(), src, meta_data_table.ToMetadata());
   parser.ParseSemVer(false);
   parser.PushScope();
   auto expr = parser.ParseExpr();
diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc
index beb32da7126c..549f7d33738e 100644
--- a/src/parser/source_map.cc
+++ b/src/parser/source_map.cc
@@ -27,7 +27,7 @@ namespace tvm {
 namespace parser {

 /*! \brief Construct a source from a string. */
-Source::Source(const std::string& source) : source(source) {
+Source::Source(const SourceName& src_name, const std::string& source) : source_name(src_name), source(source) {
   int index = 0;
   int length = 0;
   line_map.push_back({index, length});
diff --git a/src/parser/token.h b/src/parser/token.h
index 86a26cbada52..3750ec568cc8 100644
--- a/src/parser/token.h
+++ b/src/parser/token.h
@@ -378,7 +378,12 @@ int64_t Token::ToNumber() const { return Downcast(this->operator->
 std::string Token::ToString() const { return Downcast(this->operator->()->data); }

 Map<String, Array<ObjectRef>> Token::ToMetadata() const {
-  return Downcast<Map<String, Array<ObjectRef>>>(this->operator->()->data);
+  ObjectRef data = this->operator->()->data;
+  if (data.defined()) {
+    return Downcast<Map<String, Array<ObjectRef>>>(data);
+  } else {
+    return Map<String, Array<ObjectRef>>({});
+  }
 }

 }  // namespace parser
diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h
index f500d4ac6d58..7357106da41c 100644
--- a/src/parser/tokenizer.h
+++ b/src/parser/tokenizer.h
@@ -533,12 +533,22 @@ struct Tokenizer {
         tokens() {}
 };

-std::vector<Token> Condense(const std::vector<Token>& tokens) {
+std::vector<Token> Condense(const std::vector<Token>& tokens, Token* table) {
   std::vector<Token> out;
+  bool found_metadata = false;

   for (size_t i = 0; i < tokens.size(); i++) {
     auto current = tokens.at(i);
     switch (current->token_type) {
+      case TokenType::kMetadata: {
+        if (!found_metadata) {
+          found_metadata = true;
+          *table = current;
+        } else {
+          LOG(FATAL) << "duplicate metadata section";
+        }
+        continue;
+      }
       case TokenType::kPercent: {
         auto next = tokens.at(i + 1);
         if (next->token_type == TokenType::kIdentifier) {
@@ -602,15 +612,16 @@ std::vector<Token> Condense(const std::vector<Token>& tokens) {
   return out;
 }

-std::vector<Token> Tokenize(DiagnosticContext* ctx, const SourceName& source_name,
+std::pair<std::vector<Token>, Token> Tokenize(DiagnosticContext* ctx, const SourceName& source_name,
                             const std::string& source) {
   auto tokenizer = Tokenizer(ctx, source_name, source);
   tokenizer.Tokenize();
-  auto tokens = Condense(tokenizer.tokens);
+  Token meta_table(Span(), TokenType::kUnknown, ObjectRef());
+  auto tokens = Condense(tokenizer.tokens, &meta_table);
   for (auto token : tokens) {
     CHECK(token.defined());
   }
-  return tokens;
+  return { tokens, meta_table };
 }

 }  // namespace parser
diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py
index 2ec24cc50c25..50b87d2b94b0 100644
--- a/tests/python/relay/test_ir_parser.py
+++ b/tests/python/relay/test_ir_parser.py
@@ -234,9 +234,10 @@ def test_vars():
     assert op.name == "nn.global_avg_pool2d"


 def test_meta_ref():
-    meta_op = parse_text("meta[type_key][1337]")
-    assert meta_op.attrs.node_type_key == "type_key"
-    assert meta_op.attrs.node_index == 1337
+    with pytest.raises(tvm.error.DiagnosticError):
+        meta_op = parse_text("meta[type_key][1337]")
+        assert meta_op.attrs.node_type_key == "type_key"
+        assert meta_op.attrs.node_index == 1337


 def test_let():
diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py
index cd677096b2e9..52551bf68e77 100644
--- a/tests/python/relay/test_ir_text_printer.py
+++ b/tests/python/relay/test_ir_text_printer.py
@@ -17,19 +17,18 @@
 import tvm
 from tvm import te
 from tvm import relay
-import tvm.relay.testing
+from tvm.relay import testing
 import numpy as np
 from tvm.relay import Expr
 from tvm.relay.analysis import free_vars

-do_print = [True]
+DEBUG_PRINT = False

-SEMVER = "#[version = \"0.0.4\"]\n"
+SEMVER = "#[version = \"0.0.5\"]\n"

 def astext(program, unify_free_vars=False):
     text = program.astext()
     print(text)
-
     if isinstance(program, Expr):
         roundtrip_program = tvm.parser.parse_expr(text)
     else:
@@ -40,7 +39,7 @@ def astext(program, unify_free_vars=False):
     return text

 def show(text):
-    if do_print[0]:
+    if DEBUG_PRINT:
         print("---------------------------")
         print(text)

@@ -137,55 +136,55 @@ def test_variable_name():


 def test_mlp():
-    net, params = tvm.relay.testing.mlp.get_workload(batch_size=1)
+    net, _ = tvm.relay.testing.mlp.get_workload(batch_size=1)
     astext(net)


 def test_resnet():
-    net, params = tvm.relay.testing.resnet.get_workload(batch_size=1)
+    net, _ = tvm.relay.testing.resnet.get_workload(batch_size=1)
     astext(net)


 def test_mobilenet():
-    net, params = tvm.relay.testing.mobilenet.get_workload(batch_size=1)
+    net, _ = tvm.relay.testing.mobilenet.get_workload(batch_size=1)
     astext(net)


 def test_dqn():
-    net, params = tvm.relay.testing.dqn.get_workload(batch_size=1)
+    net, _ = tvm.relay.testing.dqn.get_workload(batch_size=1)
     astext(net)


 def test_dcgan():
-    net, params = tvm.relay.testing.dcgan.get_workload(batch_size=1)
+    net, _ = tvm.relay.testing.dcgan.get_workload(batch_size=1)
     astext(net)


 def test_lstm():
-    net, params = tvm.relay.testing.lstm.get_workload(1, 1)
+    net, _ = tvm.relay.testing.lstm.get_workload(1, 1)
     astext(net)

-    net, params = tvm.relay.testing.lstm.get_workload(4, 4)
+    net, _ = tvm.relay.testing.lstm.get_workload(4, 4)
     astext(net)


 def test_inception_v3():
-    net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
+    net, _ = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
     astext(net)


 def test_squeezenet():
     for version in ['1.0', '1.1']:
-        net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version)
+        net, _ = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version)
         astext(net)


 def test_vgg():
-    net, params = tvm.relay.testing.vgg.get_workload(batch_size=1)
+    net, _ = tvm.relay.testing.vgg.get_workload(batch_size=1)
     astext(net)


 def test_densenet():
-    net, params = tvm.relay.testing.densenet.get_workload(batch_size=1)
+    net, _ = tvm.relay.testing.densenet.get_workload(batch_size=1)
     astext(net)


@@ -234,7 +233,7 @@ def @main[A]() -> fn (A, List[A]) -> List[A] {
   Cons
 }
 """
-    mod = relay.fromtext(SEMVER + type_def_str + main_def_str)
+    mod = tvm.parser.parse(SEMVER + type_def_str + main_def_str)
     mod_str = str(mod)
     # ensure constructors are printed correctly in type definitions (with their
     # signature) and as exprs (without their signature)
@@ -252,25 +251,5 @@ def test_null_attribute():


 if __name__ == "__main__":
-    do_print[0] = True
-    test_lstm()
-    # test_zeros()
-    # test_meta_data()
-    # test_let_inlining()
-    # test_resnet()
-    # test_mobilenet()
-    # test_mlp()
-    # test_dqn()
-    # test_dcgan()
-    # test_squeezenet()
-    # test_inception_v3()
-    # test_vgg()
-    # test_densenet()
-    # test_func()
-    # test_env()
-    # test_call_attrs()
-    # test_let_if_scope()
-    # test_variable_name()
-    # test_call_node_order()
-    # test_unapplied_constructor()
-    # test_null_attribute()
+    import sys
+    pytest.main(sys.argv)
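
For reviewers who want to poke at the new diagnostic path locally, a minimal sketch along the lines of the updated test_meta_ref above (the check_unresolved_meta_ref helper name and the standalone-script form are illustrative, not part of the patch):

import pytest
import tvm

SEMVER = '#[version = "0.0.5"]\n'

def check_unresolved_meta_ref():
    # With no #[metadata] section in the text, a meta[...] reference cannot be
    # resolved, so parsing now reports a diagnostic instead of hitting the old
    # LOG(FATAL) "implement me" path in ParseMetaRef.
    with pytest.raises(tvm.error.DiagnosticError):
        tvm.parser.parse_expr(SEMVER + "meta[type_key][1337]")

if __name__ == "__main__":
    check_unresolved_meta_ref()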