From 6d1a38d5a53f719f70558416091fd377af644872 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 5 Aug 2020 22:55:12 -0700 Subject: [PATCH] WIP --- include/tvm/ir/span.h | 2 +- include/tvm/parser/source_map.h | 6 +- src/ir/span.cc | 12 +- src/parser/meta_ref.cc | 11 +- src/parser/meta_ref.h | 4 +- src/parser/op_table.h | 20 +- src/parser/parser.cc | 405 ++++++++++++--------- src/parser/token.h | 327 +++++++++-------- src/parser/tokenizer.h | 116 +++--- src/printer/relay_text_printer.cc | 2 +- src/printer/text_printer.h | 3 +- tests/python/relay/test_ir_text_printer.py | 64 ++-- 12 files changed, 508 insertions(+), 464 deletions(-) diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 4f1006ebcb8a..be8799eaca19 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -111,7 +111,7 @@ class SpanNode : public Object { class Span : public ObjectRef { public: - TVM_DLL Span(SourceName source, int line, int column, int end_line, int end_column); + TVM_DLL Span(SourceName source, int line, int end_line, int column, int end_column); /*! \brief Merge two spans into one which captures the combined regions. */ TVM_DLL Span Merge(const Span& other); diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index 98583ec549ba..ca0b7163aa66 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -16,13 +16,13 @@ * specific language governing permissions and limitations * under the License. */ - -#ifndef TVM_PARSER_SOURCE_MAP_H_ -#define TVM_PARSER_SOURCE_MAP_H_ /*! * \file source_map.h * \brief A map from source names to source code. */ +#ifndef TVM_PARSER_SOURCE_MAP_H_ +#define TVM_PARSER_SOURCE_MAP_H_ + #include #include #include diff --git a/src/ir/span.cc b/src/ir/span.cc index 2a2601c3f3df..e936feae1723 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -65,8 +65,8 @@ Span::Span(SourceName source, int line, int column, int end_line, int end_column auto n = make_object(); n->source = std::move(source); n->line = line; - n->column = column; n->end_line = end_line; + n->column = column; n->end_column = end_column; data_ = std::move(n); } @@ -74,21 +74,21 @@ Span::Span(SourceName source, int line, int column, int end_line, int end_column Span Span::Merge(const Span& other) { CHECK((*this)->source == other->source); return Span((*this)->source, std::min((*this)->line, other->line), - std::min((*this)->column, other->column), std::max((*this)->end_line, other->end_line), + std::min((*this)->column, other->column), std::max((*this)->end_column, other->end_column)); } TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int line, int column, - int end_line, int end_column) { - return Span(source, line, column, end_line, end_column); +TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int line, int end_line, int column, + int end_column) { + return Span(source, line, end_line, column, end_column); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "Span(" << node->source << ", " << node->line << ", " << node->column << ")"; + p->stream << "Span(" << node->source << ", " << node->line << ", " << node->end_line << ", " << node->column << ", " << node->end_column << ")"; }); } // namespace tvm diff --git a/src/parser/meta_ref.cc b/src/parser/meta_ref.cc index 2a81423b898b..d23892753c5f 100644 --- a/src/parser/meta_ref.cc +++ b/src/parser/meta_ref.cc @@ -35,6 +35,9 @@ namespace parser { using tvm::relay::transform::CreateFunctionPass; using tvm::transform::PassContext; +/* Set to arbitrary high number, since we should never schedule in normal pass manager flow. */ +static int kMetaExpandOptLevel = 1337; + TVM_REGISTER_NODE_TYPE(MetaRefAttrs); bool MetaRefRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -60,12 +63,6 @@ Expr MetaRef(std::string type_key, uint64_t node_index) { return Call(op, {}, Attrs(attrs), {}); } -// class MetaRefAttrExpander : AttrFunctor { -// ObjectRef VisitAttrDefault_(const Object* node) final { - -// } -// } - struct MetaRefExpander : public ExprMutator { MetaTable table; @@ -94,7 +91,7 @@ Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod) { auto pass = CreateFunctionPass([&](Function func, IRModule module, PassContext ctx) { return ExpandMetaRefs(meta_table, func); }, - 1337, "ExpandMetaRefs", {}); + kMetaExpandOptLevel, "ExpandMetaRefs", {}); return pass(mod, PassContext::Create()); } diff --git a/src/parser/meta_ref.h b/src/parser/meta_ref.h index 40e3fdbb7a8b..481f334cb0fe 100644 --- a/src/parser/meta_ref.h +++ b/src/parser/meta_ref.h @@ -71,12 +71,12 @@ struct MetaRefAttrs : public tvm::AttrsNode { * of the program. * * \param type_key The type key of the object in the meta section. - * \param kind The index into that subfield. + * \param node_index The index into that subfield. * \returns The meta table reference. */ Expr MetaRef(std::string type_key, uint64_t node_index); -relay::Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& mod); +relay::Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func); IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod); } // namespace parser diff --git a/src/parser/op_table.h b/src/parser/op_table.h index 5af10a0590b8..050904f23280 100644 --- a/src/parser/op_table.h +++ b/src/parser/op_table.h @@ -80,16 +80,16 @@ struct OperatorTable { OperatorTable DefaultOpTable() { return OperatorTable( - {Rule({TokenType::Star}, Op::Get("multiply"), 12, 2, true), - Rule({TokenType::Division}, Op::Get("divide"), 12, 2, true), - Rule({TokenType::Plus}, Op::Get("add"), 10, 2, true), - Rule({TokenType::Minus}, Op::Get("subtract"), 10, 2, true), - Rule({TokenType::LAngle}, Op::Get("less"), 8, 2, true), - Rule({TokenType::LAngle, TokenType::Equal}, Op::Get("less_equal"), 8, 2, true), - Rule({TokenType::RAngle}, Op::Get("greater"), 8, 2, true), - Rule({TokenType::RAngle, TokenType::Equal}, Op::Get("greater_equal"), 8, 2, true), - Rule({TokenType::Equal, TokenType::Equal}, Op::Get("equal"), 7, 2, true), - Rule({TokenType::Bang, TokenType::Equal}, Op::Get("not_equal"), 7, 2, true)}); + {Rule({TokenType::kStar}, Op::Get("multiply"), 12, 2, true), + Rule({TokenType::kDivision}, Op::Get("divide"), 12, 2, true), + Rule({TokenType::kPlus}, Op::Get("add"), 10, 2, true), + Rule({TokenType::kMinus}, Op::Get("subtract"), 10, 2, true), + Rule({TokenType::kLAngle}, Op::Get("less"), 8, 2, true), + Rule({TokenType::kLAngle, TokenType::kEqual}, Op::Get("less_equal"), 8, 2, true), + Rule({TokenType::kRAngle}, Op::Get("greater"), 8, 2, true), + Rule({TokenType::kRAngle, TokenType::kEqual}, Op::Get("greater_equal"), 8, 2, true), + Rule({TokenType::kEqual, TokenType::kEqual}, Op::Get("equal"), 7, 2, true), + Rule({TokenType::kBang, TokenType::kEqual}, Op::Get("not_equal"), 7, 2, true)}); } } // namespace parser diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 368cd0f1afd2..002ff8da5ffe 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -111,6 +111,7 @@ template class ScopeStack { private: std::vector> scope_stack; + std::unordered_map free_vars; public: /*! \brief Adds a variable binding to the current scope. */ @@ -121,6 +122,10 @@ class ScopeStack { this->scope_stack.back().name_map.insert({name, value}); } + void AddFreeVar(const std::string& name, const T& value) { + free_vars.insert({name, value}); + } + /*! \brief Looks up a variable name in the scope stack returning the matching variable * in most recent scope. */ T Lookup(const std::string& name) { @@ -130,6 +135,13 @@ class ScopeStack { return it->second; } } + + // Check if we bound a free variable declaration. + auto it = free_vars.find(name); + if (it != free_vars.end()) { + return it->second; + } + return T(); } @@ -265,10 +277,10 @@ class Parser { // For now we ignore all whitespace tokens and comments. // We can tweak this behavior later to enable white space sensitivity in the parser. while (pos < static_cast(tokens.size()) && ignore_whitespace && - (tokens.at(pos)->token_type == TokenType::Whitespace || - tokens.at(pos)->token_type == TokenType::Newline || - tokens.at(pos)->token_type == TokenType::LineComment || - tokens.at(pos)->token_type == TokenType::Comment)) { + (tokens.at(pos)->token_type == TokenType::kWhitespace || + tokens.at(pos)->token_type == TokenType::kNewline || + tokens.at(pos)->token_type == TokenType::kLineComment || + tokens.at(pos)->token_type == TokenType::kComment)) { pos++; } @@ -368,6 +380,17 @@ class Parser { return var; } + /*! \brief Bind a local variable in the expression scope. + * + * "x" -> Var("x"), these are needed to map from the raw string names + * to unique variable nodes. + */ + Var BindFreeVar(const std::string& name, const relay::Type& type_annotation) { + auto var = Var(name, type_annotation); + this->expr_scopes.AddFreeVar(name, var); + return var; + } + /*! \brief Bind a type variable in the type scope. * * "A" -> TypeVar("A", ...), these are needed to map from raw string names @@ -386,8 +409,8 @@ class Parser { Var LookupLocal(const Token& local) { auto var = this->expr_scopes.Lookup(local.ToString()); if (!var.defined()) { - diag_ctx->Emit({DiagnosticLevel::Error, local->span, - "this local variable has not been previously declared"}); + diag_ctx->Emit(Diagnostic::Error(local->span) << + "this local variable has not been previously declared"); } return var; } @@ -399,9 +422,8 @@ class Parser { TypeVar LookupTypeVar(const Token& ident) { auto var = this->type_scopes.Lookup(ident.ToString()); if (!var.defined()) { - diag_ctx->Emit( - {DiagnosticLevel::Error, ident->span, - "this type variable has not been previously declared anywhere, perhaps a typo?"}); + diag_ctx->Emit(Diagnostic::Error(ident->span) + << "this type variable has not been previously declared anywhere, perhaps a typo?"); } return var; } @@ -428,7 +450,7 @@ class Parser { /*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */ NDArray NumberToNDArray(const Token& token) { - if (token->token_type == TokenType::Integer) { + if (token->token_type == TokenType::kInteger) { DLContext ctx = {DLDeviceType::kDLCPU, 0}; auto dtype = String2DLDataType("int32"); auto data = NDArray::Empty({}, dtype, ctx); @@ -437,7 +459,7 @@ class Parser { int64_t value = Downcast(token->data); array[0] = (int32_t)value; return data; - } else if (token->token_type == TokenType::Float) { + } else if (token->token_type == TokenType::kFloat) { DLContext ctx = {DLDeviceType::kDLCPU, 0}; auto dtype = String2DLDataType("float32"); auto data = NDArray::Empty({}, dtype, ctx); @@ -479,13 +501,13 @@ class Parser { /*! \brief Parse `(` parser() `)`. */ template R Parens(std::function parser) { - return Bracket(TokenType::OpenParen, TokenType::CloseParen, parser); + return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, parser); } /*! \brief Parse `{` parser() `}`. */ template R Block(std::function parser) { - return Bracket(TokenType::LCurly, TokenType::RCurly, parser); + return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser); } /*! \brief Parses a sequence beginning with a start token, seperated by a seperator token, and @@ -502,7 +524,7 @@ class Parser { template Array ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function parse, std::function before_stop = nullptr) { - DLOG(INFO) << "Parser::ParseSequence: start=" << start << "sep=" << sep << "stop=" << stop; + DLOG(INFO) << "Parser::ParseSequence: start=" << ToString(start) << "sep=" << ToString(sep) << "stop=" << ToString(stop); Match(start); // This is for the empty arguments list case, if we have token stream @@ -522,12 +544,6 @@ class Parser { auto data = parse(); Array elements = {data}; - // parse '(' expr ','? ')' - // if we are at the end invoke leftover parser - // if (Peek()->token_type == sep && before_stop) { - // before_stop(); - // } - if (WhenMatch(stop)) { return elements; // parse '( expr ',' * ')' @@ -569,7 +585,7 @@ class Parser { // Parse the metadata section at the end. auto metadata = ParseMetadata(); - Match(TokenType::EndOfFile); + Match(TokenType::kEndOfFile); Map funcs; Map types; @@ -589,8 +605,8 @@ class Parser { /*! \brief Parse the semantic versioning header. */ SemVer ParseSemVer(bool required = true) { - if (Peek()->token_type == TokenType::Version) { - auto version = Match(TokenType::Version); + if (Peek()->token_type == TokenType::kVersion) { + auto version = Match(TokenType::kVersion); // TODO(@jroesch): we currently only support 0.0.5. if (version.ToString() != "\"0.0.5\"") { this->diag_ctx->Emit(DiagnosticBuilder(DiagnosticLevel::Error, version->span) @@ -612,9 +628,9 @@ class Parser { while (true) { auto next = Peek(); switch (next->token_type) { - case TokenType::Defn: { - Consume(TokenType::Defn); - auto global_tok = Match(TokenType::Global); + case TokenType::kDefn: { + Consume(TokenType::kDefn); + auto global_tok = Match(TokenType::kGlobal); auto global_name = global_tok.ToString(); auto global = GlobalVar(global_name); try { @@ -628,12 +644,12 @@ class Parser { defs.funcs.push_back(GlobalFunc(global, func)); continue; } - case TokenType::TypeDef: { + case TokenType::kTypeDef: { defs.types.push_back(ParseTypeDef()); continue; } - case TokenType::Extern: { - Consume(TokenType::Extern); + case TokenType::kExtern: { + Consume(TokenType::kExtern); auto type_def = ParseTypeDef(); if (type_def->constructors.size()) { diag_ctx->Emit({DiagnosticLevel::Error, next->span, @@ -650,9 +666,9 @@ class Parser { /*! \brief Parse zero or more Relay type definitions. */ TypeData ParseTypeDef() { // Match the `type` keyword. - Match(TokenType::TypeDef); + Match(TokenType::kTypeDef); // Parse the type's identifier. - auto type_tok = Match(TokenType::Identifier); + auto type_tok = Match(TokenType::kIdentifier); auto type_id = type_tok.ToString(); auto type_global = tvm::GlobalTypeVar(type_id, TypeKind::kAdtHandle); @@ -667,33 +683,33 @@ class Parser { Array generics; bool should_pop = false; - if (Peek()->token_type == TokenType::LSquare) { + if (Peek()->token_type == TokenType::kLSquare) { // If we have generics we need to add a type scope. PushTypeScope(); should_pop = true; generics = - ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, [&]() { - auto type_var_name = Match(TokenType::Identifier).ToString(); + ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { + auto type_var_name = Match(TokenType::kIdentifier).ToString(); return BindTypeVar(type_var_name, TypeKind::kType); }); } Array ctors; - if (Peek()->token_type == TokenType::LCurly) { + if (Peek()->token_type == TokenType::kLCurly) { // Parse the list of constructors. ctors = ParseSequence( - TokenType::LCurly, TokenType::Comma, TokenType::RCurly, [&]() { + TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&]() { // First match the name of the constructor. - auto ctor_tok = Match(TokenType::Identifier); + auto ctor_tok = Match(TokenType::kIdentifier); auto ctor_name = ctor_tok.ToString(); Constructor ctor; // Match the optional field list. - if (Peek()->token_type != TokenType::OpenParen) { + if (Peek()->token_type != TokenType::kOpenParen) { ctor = tvm::Constructor(ctor_name, {}, type_global); } else { auto arg_types = - ParseSequence(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, + ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() { return ParseType(); }); ctor = tvm::Constructor(ctor_name, arg_types, type_global); } @@ -757,12 +773,12 @@ class Parser { switch (next->token_type) { // For graph or let, match first rhs, then invoke ParseBindingExpr // ParseBindingExpression then parse_lhs() parse_rhs() ';' continue - case TokenType::LCurly: { + case TokenType::kLCurly: { // NB: Might need to optimize to remove deep recursion. // Stack should only grow proportionally to the number of // nested scopes. // Parses `{` expression `}`. - auto block = Bracket(TokenType::LCurly, TokenType::RCurly, [&]() { + auto block = Bracket(TokenType::kLCurly, TokenType::kRCurly, [&]() { PushScope(); auto expr = ParseExpr(); PopScopes(1); @@ -771,24 +787,32 @@ class Parser { exprs.push_back(block); break; } + case TokenType::kFreeVar: { + Consume(TokenType::kFreeVar); + auto var_token = Match(TokenType::kLocal); + Match(TokenType::kColon); + auto type = ParseType(); + BindFreeVar(var_token.ToString(), type); + break; + } // Parses `let ...`; - case TokenType::Let: + case TokenType::kLet: exprs.push_back(ParseBindingExpr()); break; - case TokenType::Match: - case TokenType::PartialMatch: { - bool is_total = next->token_type == TokenType::Match; + case TokenType::kMatch: + case TokenType::kPartialMatch: { + bool is_total = next->token_type == TokenType::kMatch; Consume(next->token_type); exprs.push_back(ParseMatch(is_total)); break; } - case TokenType::If: { + case TokenType::kIf: { exprs.push_back(ParseIf()); break; } // %x ... - case TokenType::Graph: - if (Lookahead(2)->token_type == TokenType::Equal) { + case TokenType::kGraph: + if (Lookahead(2)->token_type == TokenType::kEqual) { exprs.push_back(ParseBindingExpr()); break; } @@ -799,7 +823,7 @@ class Parser { } } - if (!WhenMatch(TokenType::Semicolon)) { + if (!WhenMatch(TokenType::kSemicolon)) { break; } } @@ -853,34 +877,34 @@ class Parser { while (true) { auto next = Peek(); - if (next->token_type == TokenType::Graph && Lookahead(2)->token_type == TokenType::Equal) { - Match(TokenType::Graph); - Match(TokenType::Equal); + if (next->token_type == TokenType::kGraph && Lookahead(2)->token_type == TokenType::kEqual) { + Match(TokenType::kGraph); + Match(TokenType::kEqual); auto val = this->ParseExprBinOp(); - Match(TokenType::Semicolon); + Match(TokenType::kSemicolon); AddGraphBinding(next, val); - } else if (next->token_type == TokenType::Let) { + } else if (next->token_type == TokenType::kLet) { // Parse the 'let'. - Consume(TokenType::Let); + Consume(TokenType::kLet); // Parse the local '%'. - auto local_tok = Match(TokenType::Local); + auto local_tok = Match(TokenType::kLocal); auto string = local_tok.ToString(); // Parse the optional type annotation (':' ). Type type; - if (WhenMatch(TokenType::Colon)) { + if (WhenMatch(TokenType::kColon)) { type = ParseType(); } auto var = BindVar(string, type); // Parse the '='; - Match(TokenType::Equal); + Match(TokenType::kEqual); // Parse the body, and the ';'. auto val = this->ParseExprBinOp(); - Consume(TokenType::Semicolon); + Consume(TokenType::kSemicolon); // Add the bindings to the local data structure. bindings.push_back({var, val}); @@ -923,30 +947,30 @@ class Parser { PushTypeScope(); Array generics; - if (Peek()->token_type == TokenType::LSquare) { + if (Peek()->token_type == TokenType::kLSquare) { // If we have generics we need to add a type scope. PushTypeScope(); generics = - ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, [&]() { - auto type_var_name = Match(TokenType::Identifier).ToString(); + ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { + auto type_var_name = Match(TokenType::kIdentifier).ToString(); return BindTypeVar(type_var_name, TypeKind::kType); }); } auto params = - ParseSequence(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, [&]() { - auto token = Match(TokenType::Local); + ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() { + auto token = Match(TokenType::kLocal); auto string = token.ToString(); Type type; - if (WhenMatch(TokenType::Colon)) { + if (WhenMatch(TokenType::kColon)) { type = ParseType(); } return BindVar(string, type); }); Type ret_type; - if (WhenMatch(TokenType::Minus)) { - Match(TokenType::RAngle); + if (WhenMatch(TokenType::kMinus)) { + Match(TokenType::kRAngle); ret_type = ParseType(); } @@ -961,7 +985,7 @@ class Parser { /*! \brief Parse an if-expression. */ Expr ParseIf() { DLOG(INFO) << "Parser::ParseIf"; - Consume(TokenType::If); + Consume(TokenType::kIf); auto guard = Parens([&] { return ParseExpr(); }); auto true_branch = Block([&] { @@ -971,7 +995,7 @@ class Parser { return expr; }); - Match(TokenType::Else); + Match(TokenType::kElse); auto false_branch = Block([&] { this->PushScope(); @@ -985,7 +1009,7 @@ class Parser { /* This factors parsing a list of patterns for both tuples, and constructors. */ Array ParsePatternList() { - return ParseSequence(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, + return ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&] { return ParsePattern(); }); } @@ -1000,24 +1024,24 @@ class Parser { DLOG(INFO) << "Parser::ParsePattern"; auto next = Peek(); switch (next->token_type) { - case TokenType::Underscore: { - Match(TokenType::Underscore); + case TokenType::kUnderscore: { + Match(TokenType::kUnderscore); return PatternWildcard(); } - case TokenType::Local: { - auto id = Match(TokenType::Local); + case TokenType::kLocal: { + auto id = Match(TokenType::kLocal); Type type_annotation; - if (WhenMatch(TokenType::Colon)) { + if (WhenMatch(TokenType::kColon)) { type_annotation = ParseType(); } auto var = BindVar(id.ToString(), type_annotation); return PatternVar(var); } - case TokenType::Identifier: { - auto id = Match(TokenType::Identifier); + case TokenType::kIdentifier: { + auto id = Match(TokenType::kIdentifier); auto ctor = ctors.Get(id.ToString()); CHECK(ctor) << "undefined identifier"; - if (Peek()->token_type == TokenType::OpenParen) { + if (Peek()->token_type == TokenType::kOpenParen) { auto fields = ParsePatternList(); return PatternConstructor(ctor.value(), fields); } else { @@ -1032,8 +1056,8 @@ class Parser { Clause ParseMatchArm() { PushScope(); auto pattern = ParsePattern(); - Match(TokenType::Equal); - Consume(TokenType::RAngle); + Match(TokenType::kEqual); + Consume(TokenType::kRAngle); auto expr = ParseExpr(); PopScopes(1); return Clause(pattern, expr); @@ -1043,7 +1067,7 @@ class Parser { Expr scrutinee = ParseExpr(); Array clauses = ParseSequence( - TokenType::LCurly, TokenType::Comma, TokenType::RCurly, [&] { return ParseMatchArm(); }); + TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&] { return ParseMatchArm(); }); return relay::Match(scrutinee, clauses, is_total); } @@ -1126,13 +1150,13 @@ class Parser { DLOG(INFO) << "Parser::ParseAttributeValue"; auto next = Peek(); switch (next->token_type) { - case TokenType::Float: - case TokenType::Integer: - case TokenType::Boolean: - case TokenType::StringLiteral: + case TokenType::kFloat: + case TokenType::kInteger: + case TokenType::kBoolean: + case TokenType::kStringLiteral: return Match(next->token_type)->data; - case TokenType::LSquare: { - return ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, + case TokenType::kLSquare: { + return ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { return ParseAttributeValue(); }); } default: @@ -1143,58 +1167,67 @@ class Parser { Map ParseAttrs() { DLOG(INFO) << "Parser::ParseAttrs"; Map kwargs; - while (Peek()->token_type == TokenType::Identifier) { - auto key = Match(TokenType::Identifier).ToString(); - Match(TokenType::Equal); + while (Peek()->token_type == TokenType::kIdentifier) { + auto key = Match(TokenType::kIdentifier).ToString(); + Match(TokenType::kEqual); // TOOD(@jroesch): syntactically what do we allow to appear in attribute right hand side. auto value = ParseAttributeValue(); + // TODO(@jroesch): we need a robust way to handle this writing dtypes as strings in text format is bad. kwargs.Set(key, value); - WhenMatch(TokenType::Comma); + WhenMatch(TokenType::kComma); } DLOG(INFO) << "Parser::ParseAttrs: kwargs=" << kwargs; return kwargs; } Expr ParseCallArgs(Expr op) { - DLOG(INFO) << "Parser::ParseCallArgs"; - Map raw_attrs; - std::string op_key; - bool is_op = false; - - if (auto op_node = op.as()) { - is_op = true; - op_key = op_node->attrs_type_key; - } + try { + DLOG(INFO) << "Parser::ParseCallArgs"; + Map raw_attrs; + std::string op_key; + bool is_op = false; + + if (auto op_node = op.as()) { + is_op = true; + op_key = op_node->attrs_type_key; + } - if (Peek()->token_type == TokenType::OpenParen) { - Array args = ParseSequence( - TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, - [&] { return ParseExpr(); }, - [&] { - auto is_ident = Lookahead(1)->token_type == TokenType::Identifier; - auto next_is_equal = Lookahead(2)->token_type == TokenType::Equal; - - if (is_op && is_ident && next_is_equal) { - raw_attrs = ParseAttrs(); - return true; - } + if (Peek()->token_type == TokenType::kOpenParen) { + Array args = ParseSequence( + TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, + [&] { return ParseExpr(); }, + [&] { + auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; + auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; + + if (is_op && is_ident && next_is_equal) { + raw_attrs = ParseAttrs(); + return true; + } - return false; - }); + return false; + }); - Attrs attrs; + Attrs attrs; - if (is_op && op_key.size()) { - // raw_attrs.Set("type_key", tvm::String("hello")); - auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); - CHECK(attr_obj.defined()); - attrs = Downcast(attr_obj); - } + if (is_op && op_key.size()) { + auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); + CHECK(attr_obj.defined()); + attrs = Downcast(attr_obj); + } - return Expr(Call(op, args, attrs, {})); - } else { - return Expr(); + return Expr(Call(op, args, attrs, {})); + } else { + return Expr(); + } + } catch (...) { + // TODO(@jroesch): AttrErrors should have fields + this->diag_ctx->Emit( + Diagnostic::Error(Peek()->span)); + // << err.what()); } + + return Expr(); } Expr ParseCallExpr() { @@ -1205,12 +1238,20 @@ class Parser { // // NB(@jroesch): this seems like a hack but in order to parse curried functions // and avoid complex grammar we will parse multiple call lists in a row. - while (Peek()->token_type == TokenType::OpenParen) { - auto new_expr = ParseCallArgs(expr); - if (new_expr.defined()) { - expr = new_expr; - } else { - break; + while (Peek()->token_type == TokenType::kOpenParen) { + try { + auto new_expr = ParseCallArgs(expr); + + if (new_expr.defined()) { + expr = new_expr; + } else { + break; + } + } catch (...) { + // TODO(@jroesch): AttrErrors should have fields + this->diag_ctx->EmitFatal( + Diagnostic::Error(Peek()->span)); + // << err.what()); } } @@ -1241,29 +1282,29 @@ class Parser { auto expr = ConsumeWhitespace([this] { auto next = Peek(); switch (next->token_type) { - case TokenType::Integer: - case TokenType::Float: { + case TokenType::kInteger: + case TokenType::kFloat: { Consume(next->token_type); auto number = NumberToNDArray(next); Expr e = Constant(number, next->span); return e; } - case TokenType::Boolean: { - Consume(TokenType::Boolean); + case TokenType::kBoolean: { + Consume(TokenType::kBoolean); int value = Downcast(next->data); auto boolean = BooleanToNDarray(value); Expr e = Constant(boolean, next->span); return e; } // Parse a local of the form `%x`. - case TokenType::Local: { - Consume(TokenType::Local); + case TokenType::kLocal: { + Consume(TokenType::kLocal); return Expr(LookupLocal(next)); } // Parse a local of the form `@x`. - case TokenType::Global: { + case TokenType::kGlobal: { auto string = next.ToString(); - Consume(TokenType::Global); + Consume(TokenType::kGlobal); auto global = global_names.Get(string); if (!global) { // TODO(@jroesch): fix global's needing span information @@ -1276,10 +1317,10 @@ class Parser { } // Parse a local of the form `x`. // Right now we fail to parse `x.y`. - case TokenType::Identifier: { + case TokenType::kIdentifier: { auto ctor = ctors.Get(next.ToString()); if (ctor) { - Consume(TokenType::Identifier); + Consume(TokenType::kIdentifier); return Expr(ctor.value()); } else { auto idents = ParseHierName(); @@ -1296,37 +1337,37 @@ class Parser { return GetOp(op_name.str(), next); } } - case TokenType::Graph: { - Consume(TokenType::Graph); + case TokenType::kGraph: { + Consume(TokenType::kGraph); return LookupGraphBinding(next); } - case TokenType::MetaReference: { - Consume(TokenType::MetaReference); + case TokenType::kMetaReference: { + Consume(TokenType::kMetaReference); return Downcast(next->data); } - case TokenType::Fn: { - Consume(TokenType::Fn); + case TokenType::kFn: { + Consume(TokenType::kFn); return Expr(ParseFunctionDef()); } - case TokenType::OpenParen: { - Consume(TokenType::OpenParen); + case TokenType::kOpenParen: { + Consume(TokenType::kOpenParen); // parse '(' ')' - if (WhenMatch(TokenType::CloseParen)) { + if (WhenMatch(TokenType::kCloseParen)) { return Expr(Tuple(Array())); } else { auto expr = ParseExpr(); // parse '(' expr ')' - if (WhenMatch(TokenType::CloseParen)) { + if (WhenMatch(TokenType::kCloseParen)) { return expr; // parse '( expr ',' * ')' - } else if (WhenMatch(TokenType::Comma)) { + } else if (WhenMatch(TokenType::kComma)) { Array exprs = {expr}; while (true) { - if (WhenMatch(TokenType::CloseParen)) { + if (WhenMatch(TokenType::kCloseParen)) { break; } else { auto expr = ParseExpr(); - WhenMatch(TokenType::Comma); + WhenMatch(TokenType::kComma); exprs.push_back(expr); } } @@ -1343,8 +1384,8 @@ class Parser { } }); - if (WhenMatch(TokenType::Period)) { - auto index = Match(TokenType::Integer).ToNumber(); + if (WhenMatch(TokenType::kPeriod)) { + auto index = Match(TokenType::kInteger).ToNumber(); expr = relay::TupleGetItem(expr, index); } @@ -1354,12 +1395,12 @@ class Parser { /*! \brief Parse a hierarchical name. */ Array ParseHierName() { Array idents; - while (Peek()->token_type == TokenType::Identifier) { + while (Peek()->token_type == TokenType::kIdentifier) { idents.push_back(Peek().ToString()); - Consume(TokenType::Identifier); + Consume(TokenType::kIdentifier); - if (Peek()->token_type == TokenType::Period) { - Consume(TokenType::Period); + if (Peek()->token_type == TokenType::kPeriod) { + Consume(TokenType::kPeriod); continue; } else { break; @@ -1371,9 +1412,9 @@ class Parser { /*! \brief Parse a shape. */ Array ParseShape() { - auto dims = ParseSequence(TokenType::OpenParen, TokenType::Comma, - TokenType::CloseParen, [&]() { - auto tok = Match(TokenType::Integer); + auto dims = ParseSequence(TokenType::kOpenParen, TokenType::kComma, + TokenType::kCloseParen, [&]() { + auto tok = Match(TokenType::kInteger); return Downcast(tok->data); }); return dims; @@ -1381,11 +1422,11 @@ class Parser { /*! \brief Parse a function type. */ Type ParseFunctionType() { - auto ty_params = ParseSequence(TokenType::OpenParen, TokenType::Comma, - TokenType::CloseParen, [&]() { return ParseType(); }); + auto ty_params = ParseSequence(TokenType::kOpenParen, TokenType::kComma, + TokenType::kCloseParen, [&]() { return ParseType(); }); - Match(TokenType::Minus); - Match(TokenType::RAngle); + Match(TokenType::kMinus); + Match(TokenType::kRAngle); auto ret_type = ParseType(); return relay::FuncType(ty_params, ret_type, {}, {}); @@ -1406,8 +1447,8 @@ class Parser { CHECK(head_type.defined()) << "internal error: head type must be defined"; Array arg_types; - if (Peek()->token_type == TokenType::LSquare) { - arg_types = ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, + if (Peek()->token_type == TokenType::kLSquare) { + arg_types = ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { return ParseType(); }); } @@ -1426,21 +1467,21 @@ class Parser { Type ParseType() { auto tok = Peek(); - if (tok->token_type == TokenType::OpenParen) { - auto tys = ParseSequence(TokenType::OpenParen, TokenType::Comma, - TokenType::CloseParen, [&]() { return ParseType(); }); + if (tok->token_type == TokenType::kOpenParen) { + auto tys = ParseSequence(TokenType::kOpenParen, TokenType::kComma, + TokenType::kCloseParen, [&]() { return ParseType(); }); return relay::TupleType(tys); - } else if (WhenMatch(TokenType::Fn)) { + } else if (WhenMatch(TokenType::kFn)) { return ParseFunctionType(); - } else if (WhenMatch(TokenType::Identifier)) { + } else if (WhenMatch(TokenType::kIdentifier)) { auto id = tok.ToString(); if (id == "Tensor") { - Match(TokenType::LSquare); + Match(TokenType::kLSquare); auto shape = ParseShape(); - Match(TokenType::Comma); - auto dtype_tok = Match(TokenType::Identifier); + Match(TokenType::kComma); + auto dtype_tok = Match(TokenType::kIdentifier); auto dtype = DataType(String2DLDataType(dtype_tok.ToString())); - Match(TokenType::RSquare); + Match(TokenType::kRSquare); return TensorType(shape, dtype); } else { auto ty = tok.ToString(); @@ -1454,7 +1495,7 @@ class Parser { } } } - if (WhenMatch(TokenType::Underscore)) { + if (WhenMatch(TokenType::kUnderscore)) { return IncompleteType(); } else { this->diag_ctx->EmitFatal(DiagnosticBuilder(DiagnosticLevel::Error, tok->span) @@ -1467,7 +1508,7 @@ class Parser { R ConsumeWhitespace(std::function func) { auto old = this->ignore_whitespace; this->ignore_whitespace = true; - while (tokens[pos]->token_type == TokenType::Whitespace) { + while (tokens[pos]->token_type == TokenType::kWhitespace) { pos++; } auto res = func(); @@ -1476,8 +1517,8 @@ class Parser { } Map> ParseMetadata() { - if (Peek()->token_type == TokenType::Metadata) { - return Match(TokenType::Metadata).ToMetadata(); + if (Peek()->token_type == TokenType::kMetadata) { + return Match(TokenType::kMetadata).ToMetadata(); } else { return Map>(); } @@ -1534,7 +1575,7 @@ Expr ParseExpr(std::string file_name, std::string file_content) { parser.ParseSemVer(false); parser.PushScope(); auto expr = parser.ParseExpr(); - parser.Match(TokenType::EndOfFile); + parser.Match(TokenType::kEndOfFile); // NB(@jroesch): it is very important that we render any errors before we procede // if there were any errors which allow the parser to procede we must render them // here. diff --git a/src/parser/token.h b/src/parser/token.h index 480872956b68..86a26cbada52 100644 --- a/src/parser/token.h +++ b/src/parser/token.h @@ -38,169 +38,172 @@ namespace parser { using namespace runtime; -enum TokenType { - CommentStart, - CommentEnd, - LineComment, - Comment, - Whitespace, - Newline, - StringLiteral, - Identifier, - Local, - Global, - Op, - Graph, - OpenParen, - CloseParen, - AtSymbol, - Percent, - Comma, - Period, - Equal, - Semicolon, - Colon, - Integer, - Float, - Division, - Boolean, - Plus, - Star, - Minus, - RAngle, - LAngle, - RCurly, - LCurly, - RSquare, - LSquare, - Bang, - At, - Question, - If, - Else, - Underscore, - Let, - Fn, - Defn, - TypeDef, - Extern, - Match, - PartialMatch, - Metadata, - MetaReference, - Version, - Unknown, - EndOfFile, - Null, +enum class TokenType { + kCommentStart, + kCommentEnd, + kLineComment, + kComment, + kWhitespace, + kNewline, + kStringLiteral, + kIdentifier, + kLocal, + kGlobal, + kOp, + kGraph, + kOpenParen, + kCloseParen, + kAtSymbol, + kPercent, + kComma, + kPeriod, + kEqual, + kSemicolon, + kColon, + kInteger, + kFloat, + kDivision, + kBoolean, + kPlus, + kStar, + kMinus, + kRAngle, + kLAngle, + kRCurly, + kLCurly, + kRSquare, + kLSquare, + kBang, + kAt, + kQuestion, + kIf, + kElse, + kUnderscore, + kLet, + kFn, + kDefn, + kTypeDef, + kExtern, + kMatch, + kPartialMatch, + kMetadata, + kMetaReference, + kFreeVar, + kVersion, + kUnknown, + kEndOfFile, + kNull, }; std::string ToString(const TokenType& token_type) { switch (token_type) { - case TokenType::CommentStart: + case TokenType::kCommentStart: return "CommentStart"; - case TokenType::CommentEnd: + case TokenType::kCommentEnd: return "CommentEnd"; - case TokenType::LineComment: + case TokenType::kLineComment: return "LineComment"; - case TokenType::Comment: + case TokenType::kComment: return "Comment"; - case TokenType::Whitespace: + case TokenType::kWhitespace: return "WhiteSpace"; - case TokenType::Newline: + case TokenType::kNewline: return "Newline"; - case TokenType::StringLiteral: + case TokenType::kStringLiteral: return "StringLiteral"; - case TokenType::Identifier: + case TokenType::kIdentifier: return "Identifier"; - case TokenType::Local: + case TokenType::kLocal: return "Local"; - case TokenType::Global: + case TokenType::kGlobal: return "Global"; - case TokenType::Graph: + case TokenType::kGraph: return "Graph"; - case TokenType::Op: + case TokenType::kOp: return "Op"; - case TokenType::OpenParen: + case TokenType::kOpenParen: return "OpenParen"; - case TokenType::CloseParen: + case TokenType::kCloseParen: return "CloseParen"; - case TokenType::AtSymbol: + case TokenType::kAtSymbol: return "AtSymbol"; - case TokenType::Percent: + case TokenType::kPercent: return "Percent"; - case TokenType::Comma: + case TokenType::kComma: return "Comma"; - case TokenType::Colon: + case TokenType::kColon: return "Colon"; - case TokenType::Semicolon: + case TokenType::kSemicolon: return "Semicolon"; - case TokenType::Period: + case TokenType::kPeriod: return "Period"; - case TokenType::Equal: + case TokenType::kEqual: return "Equal"; - case TokenType::Integer: + case TokenType::kInteger: return "Integer"; - case TokenType::Float: + case TokenType::kFloat: return "Float"; - case TokenType::Plus: + case TokenType::kPlus: return "Plus"; - case TokenType::Star: + case TokenType::kStar: return "Star"; - case TokenType::Minus: + case TokenType::kMinus: return "Minus"; - case TokenType::Division: + case TokenType::kDivision: return "Division"; - case TokenType::RAngle: + case TokenType::kRAngle: return "RAngle"; - case TokenType::LAngle: + case TokenType::kLAngle: return "LAngle"; - case TokenType::RCurly: + case TokenType::kRCurly: return "RCurly"; - case TokenType::LCurly: + case TokenType::kLCurly: return "LCurly"; - case TokenType::RSquare: + case TokenType::kRSquare: return "RSquare"; - case TokenType::LSquare: + case TokenType::kLSquare: return "LSquare"; - case TokenType::Bang: + case TokenType::kBang: return "Bang"; - case TokenType::Underscore: + case TokenType::kUnderscore: return "Underscore"; - case TokenType::At: + case TokenType::kAt: return "At"; - case TokenType::Let: + case TokenType::kLet: return "Let"; - case TokenType::If: + case TokenType::kIf: return "If"; - case TokenType::Else: + case TokenType::kElse: return "Else"; - case TokenType::Fn: + case TokenType::kFn: return "Fn"; - case TokenType::Defn: + case TokenType::kDefn: return "Defn"; - case TokenType::TypeDef: + case TokenType::kTypeDef: return "TypeDef"; - case TokenType::Extern: + case TokenType::kExtern: return "Extern"; - case TokenType::Match: + case TokenType::kMatch: return "Match"; - case TokenType::PartialMatch: + case TokenType::kPartialMatch: return "PartialMatch"; - case TokenType::Question: + case TokenType::kQuestion: return "Question"; - case TokenType::Boolean: + case TokenType::kBoolean: return "Boolean"; - case TokenType::Metadata: + case TokenType::kMetadata: return "Metadata"; - case TokenType::MetaReference: + case TokenType::kMetaReference: return "MetaReference"; - case TokenType::Version: + case TokenType::kFreeVar: + return "FreeVar"; + case TokenType::kVersion: return "Version"; - case TokenType::Unknown: + case TokenType::kUnknown: return "Unknown"; - case TokenType::EndOfFile: + case TokenType::kEndOfFile: return "EndOfFile"; - case TokenType::Null: + case TokenType::kNull: return "Null"; // Older compilers warn even though the above code is exhaustive. default: @@ -211,111 +214,113 @@ std::string ToString(const TokenType& token_type) { std::string Pretty(const TokenType& token_type) { switch (token_type) { - case TokenType::CommentStart: + case TokenType::kCommentStart: return "`/*`"; - case TokenType::CommentEnd: + case TokenType::kCommentEnd: return "`*/`"; - case TokenType::LineComment: + case TokenType::kLineComment: return "`//`"; - case TokenType::Comment: + case TokenType::kComment: return "comment"; - case TokenType::Whitespace: + case TokenType::kWhitespace: return "whitespace"; - case TokenType::Newline: + case TokenType::kNewline: return "newline"; - case TokenType::StringLiteral: + case TokenType::kStringLiteral: return "string literal"; - case TokenType::Identifier: + case TokenType::kIdentifier: return "identifier"; - case TokenType::Local: + case TokenType::kLocal: return "local variable"; - case TokenType::Global: + case TokenType::kGlobal: return "global variable"; - case TokenType::Graph: + case TokenType::kGraph: return "graph variable"; - case TokenType::Op: + case TokenType::kOp: return "operator"; - case TokenType::OpenParen: + case TokenType::kOpenParen: return "`(`"; - case TokenType::CloseParen: + case TokenType::kCloseParen: return "`)`"; - case TokenType::AtSymbol: + case TokenType::kAtSymbol: return "`@`"; - case TokenType::Percent: + case TokenType::kPercent: return "`%`"; - case TokenType::Comma: + case TokenType::kComma: return "`,`"; - case TokenType::Colon: + case TokenType::kColon: return "`:`"; - case TokenType::Semicolon: + case TokenType::kSemicolon: return "`;`"; - case TokenType::Period: + case TokenType::kPeriod: return "`.`"; - case TokenType::Equal: + case TokenType::kEqual: return "`=`"; - case TokenType::Integer: + case TokenType::kInteger: return "integer"; - case TokenType::Float: + case TokenType::kFloat: return "float"; - case TokenType::Plus: + case TokenType::kPlus: return "`+`"; - case TokenType::Star: + case TokenType::kStar: return "`*`"; - case TokenType::Minus: + case TokenType::kMinus: return "`-`"; - case TokenType::Division: + case TokenType::kDivision: return "`/`"; - case TokenType::RAngle: + case TokenType::kRAngle: return "`<`"; - case TokenType::LAngle: + case TokenType::kLAngle: return "`>`"; - case TokenType::RCurly: + case TokenType::kRCurly: return "`}`"; - case TokenType::LCurly: + case TokenType::kLCurly: return "`{`"; - case TokenType::RSquare: + case TokenType::kRSquare: return "`]`"; - case TokenType::LSquare: + case TokenType::kLSquare: return "`[`"; - case TokenType::Bang: + case TokenType::kBang: return "`!`"; - case TokenType::Underscore: + case TokenType::kUnderscore: return "`_`"; - case TokenType::At: + case TokenType::kAt: return "`@`"; - case TokenType::Let: + case TokenType::kLet: return "`let`"; - case TokenType::If: + case TokenType::kIf: return "`if`"; - case TokenType::Else: + case TokenType::kElse: return "`else`"; - case TokenType::Fn: + case TokenType::kFn: return "`fn`"; - case TokenType::Defn: + case TokenType::kDefn: return "`def`"; - case TokenType::TypeDef: + case TokenType::kTypeDef: return "`type`"; - case TokenType::Extern: + case TokenType::kExtern: return "`extern`"; - case TokenType::Boolean: + case TokenType::kBoolean: return "boolean"; - case TokenType::Metadata: + case TokenType::kMetadata: return "metadata section"; - case TokenType::MetaReference: + case TokenType::kMetaReference: return "`meta`"; - case TokenType::Match: + case TokenType::kFreeVar: + return "`free_var`"; + case TokenType::kMatch: return "`match`"; - case TokenType::PartialMatch: + case TokenType::kPartialMatch: return "`match?`"; - case TokenType::Question: + case TokenType::kQuestion: return "`?`"; - case TokenType::Unknown: + case TokenType::kUnknown: return "unknown"; - case TokenType::EndOfFile: + case TokenType::kEndOfFile: return "end of file"; - case TokenType::Null: + case TokenType::kNull: return "null"; - case TokenType::Version: + case TokenType::kVersion: return "version attribute"; // Older compilers warn even though the above code is exhaustive. default: @@ -366,7 +371,7 @@ Token::Token(Span span, TokenType token_type, ObjectRef data) { data_ = std::move(n); } -Token Token::Null() { return Token(Span(SourceName(), 0, 0, 0, 0), TokenType::Null); } +Token Token::Null() { return Token(Span(SourceName(), 0, 0, 0, 0), TokenType::kNull); } int64_t Token::ToNumber() const { return Downcast(this->operator->()->data); } diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 0456ece4e293..f500d4ac6d58 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -66,9 +66,9 @@ bool IsIdentLetter(char c) { return '_' == c || ('a' <= c && c <= 'z') || ('A' < bool IsIdent(char c) { return IsIdentLetter(c) || IsDigit(c); } static std::unordered_map KEYWORD_TABLE = { - {"let", TokenType::Let}, {"fn", TokenType::Fn}, {"def", TokenType::Defn}, - {"if", TokenType::If}, {"else", TokenType::Else}, {"type", TokenType::TypeDef}, - {"match", TokenType::Match}, {"extern", TokenType::Extern}}; + {"let", TokenType::kLet}, {"fn", TokenType::kFn}, {"def", TokenType::kDefn}, + {"if", TokenType::kIf}, {"else", TokenType::kElse}, {"type", TokenType::kTypeDef}, + {"match", TokenType::kMatch}, {"extern", TokenType::kExtern}, {"free_var", TokenType::kFreeVar}}; struct Tokenizer { DiagnosticContext* diag_ctx; @@ -102,14 +102,14 @@ struct Tokenizer { Token NewToken(TokenType token_type, ObjectRef data = ObjectRef(), int lines = 0, int cols = 1) { auto span = - Span(this->source_name, this->line, this->col, this->line + lines, this->col + cols); + Span(this->source_name, this->line, this->line + lines, this->col, this->col + cols); return Token(span, token_type, data); } Span SpanFrom(int line, int column) { int end_line = this->line; int end_column = this->col; - return Span(this->source_name, line, column, end_line, end_column); + return Span(this->source_name, line, end_line, column, end_column); } enum CommentParserState { @@ -172,7 +172,7 @@ struct Tokenizer { if (is_float) { throw std::invalid_argument("is_float"); } - auto token = NewToken(TokenType::Integer); + auto token = NewToken(TokenType::kInteger); size_t index = 0; int value = std::stoi(number, &index); if (number.size() > index) { @@ -182,7 +182,7 @@ struct Tokenizer { token->data = tvm::Integer(value); return token; } catch (const std::invalid_argument& ia) { - auto token = NewToken(TokenType::Float); + auto token = NewToken(TokenType::kFloat); if (number.back() == 'f') { number.pop_back(); @@ -233,10 +233,8 @@ struct Tokenizer { Next(); // todo: add error handling around bad indices auto index = ParseNumber(true, false, str_index.str()).ToNumber(); - int end_line = this->line; - int end_column = this->col; - auto span = Span(this->source_name, line, column, end_line, end_column); - return Token(span, TokenType::MetaReference, MetaRef(type_key.str(), index)); + auto span = SpanFrom(line, column); + return Token(span, TokenType::kMetaReference, MetaRef(type_key.str(), index)); } Token TokenizeAttr() { @@ -266,14 +264,14 @@ struct Tokenizer { } ObjectRef metadata_map = tvm::LoadJSON(metadata.str()); auto span = SpanFrom(line, column); - return Token(span, TokenType::Metadata, metadata_map); + return Token(span, TokenType::kMetadata, metadata_map); } if (attribute.rfind("version", 0) == 0) { std::string version = attribute.substr(attribute.find("=") + 1); ltrim(version); rtrim(version); auto span = SpanFrom(line, column); - return Token(span, TokenType::Version, tvm::String(version)); + return Token(span, TokenType::kVersion, tvm::String(version)); } else { // TOOD(@jroesch): maybe make this a warning an continue parsing? auto span = SpanFrom(line, column); @@ -296,13 +294,13 @@ struct Tokenizer { auto next = Peek(); DLOG(INFO) << "tvm::parser::TokenizeOnce: next=" << next; if (next == '\n') { - auto token = NewToken(TokenType::Newline); + auto token = NewToken(TokenType::kNewline); Next(); return token; } else if (next == '\r') { Next(); if (More() && Peek() == '\n') { - auto token = NewToken(TokenType::Newline); + auto token = NewToken(TokenType::kNewline); return token; } else { auto span = SpanFrom(line, col); @@ -320,9 +318,9 @@ struct Tokenizer { string_content << Next(); } Next(); - return NewToken(TokenType::StringLiteral, tvm::String(string_content.str())); + return NewToken(TokenType::kStringLiteral, tvm::String(string_content.str())); } else if (IsWhitespace(next)) { - auto token = NewToken(TokenType::Whitespace); + auto token = NewToken(TokenType::kWhitespace); Next(); return token; } else if (IsDigit(next) || next == '-') { @@ -336,7 +334,7 @@ struct Tokenizer { // with multi-token return or something. if (negs && !IsDigit(Peek())) { pos = pos - (negs - 1); - return NewToken(TokenType::Minus); + return NewToken(TokenType::kMinus); } bool is_neg = negs % 2 == 1; @@ -354,79 +352,79 @@ struct Tokenizer { return ParseNumber(!is_neg, is_float, ss.str()); } else if (next == '.') { - auto token = NewToken(TokenType::Period); + auto token = NewToken(TokenType::kPeriod); Next(); return token; } else if (next == ',') { - auto token = NewToken(TokenType::Comma); + auto token = NewToken(TokenType::kComma); Next(); return token; } else if (next == '=') { - auto token = NewToken(TokenType::Equal); + auto token = NewToken(TokenType::kEqual); Next(); return token; } else if (next == ';') { - auto token = NewToken(TokenType::Semicolon); + auto token = NewToken(TokenType::kSemicolon); Next(); return token; } else if (next == ':') { - auto token = NewToken(TokenType::Colon); + auto token = NewToken(TokenType::kColon); Next(); return token; } else if (next == '(') { - auto token = NewToken(TokenType::OpenParen); + auto token = NewToken(TokenType::kOpenParen); Next(); return token; } else if (next == ')') { - auto token = NewToken(TokenType::CloseParen); + auto token = NewToken(TokenType::kCloseParen); Next(); return token; } else if (next == '+') { - auto token = NewToken(TokenType::Plus); + auto token = NewToken(TokenType::kPlus); Next(); return token; } else if (next == '-') { - auto token = NewToken(TokenType::Minus); + auto token = NewToken(TokenType::kMinus); Next(); return token; } else if (next == '*') { - auto token = NewToken(TokenType::Star); + auto token = NewToken(TokenType::kStar); Next(); return token; } else if (next == '<') { - auto token = NewToken(TokenType::LAngle); + auto token = NewToken(TokenType::kLAngle); Next(); return token; } else if (next == '>') { - auto token = NewToken(TokenType::RAngle); + auto token = NewToken(TokenType::kRAngle); Next(); return token; } else if (next == '{') { - auto token = NewToken(TokenType::LCurly); + auto token = NewToken(TokenType::kLCurly); Next(); return token; } else if (next == '}') { - auto token = NewToken(TokenType::RCurly); + auto token = NewToken(TokenType::kRCurly); Next(); return token; } else if (next == '[') { - auto token = NewToken(TokenType::LSquare); + auto token = NewToken(TokenType::kLSquare); Next(); return token; } else if (next == ']') { - auto token = NewToken(TokenType::RSquare); + auto token = NewToken(TokenType::kRSquare); Next(); return token; } else if (next == '!') { - auto token = NewToken(TokenType::Bang); + auto token = NewToken(TokenType::kBang); Next(); return token; } else if (next == '@') { - auto token = NewToken(TokenType::At); + auto token = NewToken(TokenType::kAt); Next(); return token; } else if (next == '?') { - auto token = NewToken(TokenType::Question); + auto token = NewToken(TokenType::kQuestion); Next(); return token; } else if (MatchString("meta")) { @@ -434,7 +432,7 @@ struct Tokenizer { } else if (next == '#') { return TokenizeAttr(); } else if (next == '%') { - auto token = NewToken(TokenType::Percent); + auto token = NewToken(TokenType::kPercent); Next(); std::stringstream number; @@ -446,14 +444,14 @@ struct Tokenizer { if (number_str.size()) { auto num_tok = ParseNumber(true, false, number_str); auto span = SpanFrom(token->span->line, token->span->column); - token = Token(span, TokenType::Graph, num_tok->data); + token = Token(span, TokenType::kGraph, num_tok->data); } return token; } else if (next == '/') { Next(); if (Peek() == '/') { - auto token = NewToken(TokenType::LineComment); + auto token = NewToken(TokenType::kLineComment); // Consume the / Next(); std::stringstream comment; @@ -467,10 +465,10 @@ struct Tokenizer { Next(); std::string comment; MatchComment(&comment); - auto token = NewToken(TokenType::Comment, tvm::String(comment)); + auto token = NewToken(TokenType::kComment, tvm::String(comment)); return token; } else { - return NewToken(TokenType::Division); + return NewToken(TokenType::kDivision); } } else if (IsIdentLetter(next)) { std::stringstream ss; @@ -491,14 +489,14 @@ struct Tokenizer { if (it != KEYWORD_TABLE.end()) { token_type = it->second; - if (token_type == TokenType::Match) { + if (token_type == TokenType::kMatch) { if (More() && Peek() == '?') { Next(); - token_type = TokenType::PartialMatch; + token_type = TokenType::kPartialMatch; } } } else { - token_type = TokenType::Identifier; + token_type = TokenType::kIdentifier; } auto span = SpanFrom(line, col); @@ -508,7 +506,7 @@ struct Tokenizer { while (More() && !IsWhitespace(Peek())) { ss << Next(); } - auto token = NewToken(TokenType::Unknown); + auto token = NewToken(TokenType::kUnknown); token->data = tvm::String(ss.str()); return token; } @@ -521,7 +519,7 @@ struct Tokenizer { CHECK(token.defined()); this->tokens.push_back(token); } - this->tokens.push_back(NewToken(TokenType::EndOfFile)); + this->tokens.push_back(NewToken(TokenType::kEndOfFile)); } explicit Tokenizer(DiagnosticContext* ctx, const SourceName& source_name, @@ -541,18 +539,18 @@ std::vector Condense(const std::vector& tokens) { for (size_t i = 0; i < tokens.size(); i++) { auto current = tokens.at(i); switch (current->token_type) { - case TokenType::Percent: { + case TokenType::kPercent: { auto next = tokens.at(i + 1); - if (next->token_type == TokenType::Identifier) { + if (next->token_type == TokenType::kIdentifier) { // Match this token. i += 1; // TODO(@jroesch): merge spans - auto tok = Token(current->span, TokenType::Local, next->data); + auto tok = Token(current->span, TokenType::kLocal, next->data); CHECK(tok.defined()); out.push_back(tok); - } else if (next->token_type == TokenType::Integer) { + } else if (next->token_type == TokenType::kInteger) { i += 1; - auto tok = Token(current->span, TokenType::Graph, next->data); + auto tok = Token(current->span, TokenType::kGraph, next->data); CHECK(tok.defined()); out.push_back(tok); } else { @@ -561,13 +559,13 @@ std::vector Condense(const std::vector& tokens) { } continue; } - case TokenType::At: { + case TokenType::kAt: { auto next = tokens.at(i + 1); - if (next->token_type == TokenType::Identifier) { + if (next->token_type == TokenType::kIdentifier) { // Match this token. i += 1; // TODO(@jroesch): merge spans - auto tok = Token(current->span, TokenType::Global, next->data); + auto tok = Token(current->span, TokenType::kGlobal, next->data); CHECK(tok.defined()); out.push_back(tok); } else { @@ -576,18 +574,18 @@ std::vector Condense(const std::vector& tokens) { } continue; } - case TokenType::Identifier: { + case TokenType::kIdentifier: { std::string str = Downcast(current->data); Token tok; // TODO(@jroesch): merge spans if (str == "True") { auto data = tvm::Integer(1); - tok = Token(current->span, TokenType::Boolean, data); + tok = Token(current->span, TokenType::kBoolean, data); } else if (str == "False") { auto data = tvm::Integer(0); - tok = Token(current->span, TokenType::Boolean, data); + tok = Token(current->span, TokenType::kBoolean, data); } else if (str == "_") { - tok = Token(current->span, TokenType::Underscore); + tok = Token(current->span, TokenType::kUnderscore); } else { tok = current; } diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 90cf428f1ca1..1b09052a63d8 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -275,7 +275,7 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) { if (expr.as()) { // This is our first time visiting the var and we hit the VarNode case // in the visitor. Thus the variable is free. - doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine(); + doc_stack_.back() << "free_var " << printed_expr << ";" << Doc::NewLine(); // Memoization is done in AllocVar. return memo_[expr]; } else if (inline_expr) { diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index c01457745e97..b65b03c38063 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -365,7 +365,8 @@ class TextPrinter { /*! \brief whether show meta data */ bool show_meta_data_; - /*! \brief whether show meta data */ + + /*! \brief whether show the meta data warning message */ bool show_warning_; /*! \brief meta data context */ diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 2a88c0c99ae7..cd677096b2e9 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -22,20 +22,22 @@ from tvm.relay import Expr from tvm.relay.analysis import free_vars -do_print = [False] +do_print = [True] -SEMVER = "v0.0.4\n" +SEMVER = "#[version = \"0.0.4\"]\n" -def astext(p, unify_free_vars=False): - txt = p.astext() - if isinstance(p, Expr) and free_vars(p): - return txt - x = relay.fromtext(txt) - if unify_free_vars: - tvm.ir.assert_structural_equal(x, p, map_free_vars=True) +def astext(program, unify_free_vars=False): + text = program.astext() + print(text) + + if isinstance(program, Expr): + roundtrip_program = tvm.parser.parse_expr(text) else: - tvm.ir.assert_structural_equal(x, p) - return txt + roundtrip_program = tvm.parser.fromtext(text) + + tvm.ir.assert_structural_equal(roundtrip_program, program, map_free_vars=True) + + return text def show(text): if do_print[0]: @@ -252,23 +254,23 @@ def test_null_attribute(): if __name__ == "__main__": do_print[0] = True test_lstm() - test_zeros() - test_meta_data() - test_let_inlining() - test_resnet() - test_mobilenet() - test_mlp() - test_dqn() - test_dcgan() - test_squeezenet() - test_inception_v3() - test_vgg() - test_densenet() - test_func() - test_env() - test_call_attrs() - test_let_if_scope() - test_variable_name() - test_call_node_order() - test_unapplied_constructor() - test_null_attribute() + # test_zeros() + # test_meta_data() + # test_let_inlining() + # test_resnet() + # test_mobilenet() + # test_mlp() + # test_dqn() + # test_dcgan() + # test_squeezenet() + # test_inception_v3() + # test_vgg() + # test_densenet() + # test_func() + # test_env() + # test_call_attrs() + # test_let_if_scope() + # test_variable_name() + # test_call_node_order() + # test_unapplied_constructor() + # test_null_attribute()