From d274e4b3d33e8038296dfddbb4d9d1de8e0735aa Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 17 Jan 2019 08:57:16 -0800 Subject: [PATCH] [Relay][Parser] Improve Relay parser and pretty printing, including CMAKE (#2377) --- cmake/modules/ANTLR.cmake | 24 +- include/tvm/relay/base.h | 4 +- python/tvm/relay/_base.py | 5 + python/tvm/relay/_parser.py | 135 ++++++++-- python/tvm/relay/base.py | 10 + python/tvm/relay/grammar/Relay.g4 | 56 +++-- python/tvm/relay/parser.py | 12 +- src/relay/ir/base.cc | 14 ++ tests/python/relay/test_ir_parser.py | 359 ++++++++++++++------------- 9 files changed, 400 insertions(+), 219 deletions(-) create mode 100644 python/tvm/relay/_base.py diff --git a/cmake/modules/ANTLR.cmake b/cmake/modules/ANTLR.cmake index 72eb5925bda0..aede0098b7fb 100644 --- a/cmake/modules/ANTLR.cmake +++ b/cmake/modules/ANTLR.cmake @@ -1,7 +1,15 @@ if(USE_ANTLR) - if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar) - set(ANTLR4 "/usr/local/lib/antlr-4.7.1-complete.jar") + file(GLOB_RECURSE ANTLR4 + /usr/local/lib/antlr-*-complete.jar + /usr/local/Cellar/*antlr-*-complete.jar) + if(DEFINED ANTLR4) + # Get the first element of the list of antlr jars. + # Sort and reverse the list so the item selected is the highest + # version in lib or else in Cellar if no lib installation exists. + list(SORT ANTLR4) + list(REVERSE ANTLR4) + list(GET ANTLR4 0 ANTLR4) set(RELAY_PARSER_DIR ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) @@ -14,15 +22,21 @@ if(USE_ANTLR) ${RELAY_PARSER_DIR}/py3/RelayParser.py ${RELAY_PARSER_DIR}/py3/RelayLexer.py) + set(JAVA_HOME $ENV{JAVA_HOME}) + if (NOT DEFINED JAVA_HOME) + # Hack to get system to search for Java itself. + set(JAVA_HOME "/usr") + endif() + # Generate ANTLR grammar for parsing. add_custom_command(OUTPUT ${RELAY_PARSER} - COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 - COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 + COMMAND ${JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 + COMMAND ${JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 WORKING_DIRECTORY ${RELAY_PARSER_DIR}) add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) else() - message(FATAL_ERROR "Can't find ANTLR4!") + message(FATAL_ERROR "Can't find ANTLR4: ANTLR4=" ${ANTLR4}) endif() endif(USE_ANTLR) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index f72f557a9765..f90acdc9400b 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -108,7 +108,9 @@ class SourceName : public NodeRef { * \brief access the internal node container * \return the pointer to the internal node container */ - inline const SourceNameNode* operator->() const; + inline const SourceNameNode* operator->() const { + return static_cast(this->node_.get()); + } /*! * \brief Get an SourceName for a given operator name. diff --git a/python/tvm/relay/_base.py b/python/tvm/relay/_base.py new file mode 100644 index 000000000000..b23655a0406a --- /dev/null +++ b/python/tvm/relay/_base.py @@ -0,0 +1,5 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable +"""The interface of expr function exposed from C++.""" +from tvm._ffi.function import _init_api + +_init_api("relay._base", __name__) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 2637e7e00f77..c0455a3361e9 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -6,13 +6,17 @@ import sys from collections import deque -from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any +from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any, Dict + +import tvm from . import module +from .base import Span, SourceName from . import expr from . import ty from . import op + class ParseError(Exception): """Exception type for parse errors.""" @@ -76,22 +80,46 @@ def lookup(scopes, name): return val return None +def spanify(f): + """A decorator which attaches span information + to the value returned by calling `f`. + + Intended for use with the below AST visiting + methods. The idea is that after we do the work + of constructing the AST we attach Span information. + """ + + def _wrapper(*args, **kwargs): + # Assumes 0th arg is self and gets source_name from object. + sn = args[0].source_name + # Assumes 1st arg is an ANTLR parser context. + ctx = args[1] + ast = f(*args, **kwargs) + line, col = ctx.getSourceInterval() + sp = Span(sn, line, col) + ast.set_span(sp) + return ast + return _wrapper + # TODO(@jmp): Use https://stackoverflow.com/q/13889941 # to figure out how to get ANTLR4 to be more unhappy about syntax errors class ParseTreeToRelayIR(RelayVisitor): """Parse Relay text format into Relay IR.""" - def __init__(self): - # type: () -> None + def __init__(self, source_name): + # type: (str) -> None + self.source_name = source_name self.module = module.Module({}) # type: module.Module # Adding an empty scope allows naked lets without pain. - self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] - self.global_var_scope = deque() # type: Scope[expr.GlobalVar] - self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] + self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] + self.global_var_scope = deque() # type: Scope[expr.GlobalVar] + self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] + self.graph_expr = [] # type: List[expr.Expr] super(ParseTreeToRelayIR, self).__init__() + def enter_var_scope(self): # type: () -> None """Enter a new Var scope so it can be popped off later.""" @@ -146,20 +174,25 @@ def visitTerminal(self, node): node_type = node.getSymbol().type node_text = node.getText() + name = node_text[1:] # variables if node_type == RelayLexer.GLOBAL_VAR: - return lookup([self.global_var_scope], node_text[1:]) + return lookup(deque([self.global_var_scope]), node_text[1:]) elif node_type == RelayLexer.LOCAL_VAR: - name = node_text[1:] + # Remove the leading '%' and lookup the name. var = lookup(self.var_scopes, name) if var is None: raise ParseError("Couldn't resolve `{}`.".format(name)) - return var + elif node_type == RelayLexer.GRAPH_VAR: + try: + return self.graph_expr[int(name)] + except IndexError: + raise ParseError("Couldn't resolve `{}`".format(name)) # data types - elif node_type == RelayLexer.INT: + elif node_type == RelayLexer.NAT: return int(node_text) elif node_type == RelayLexer.FLOAT: return float(node_text) @@ -190,7 +223,7 @@ def getType_(self, ctx): return self.visit(ctx) def visitProg(self, ctx): - # type: (RelayParser.ProgContext) -> Union[expr.Expr, env.Environment] + # type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module] if ctx.defn(): self.visit_list(ctx.defn()) return self.module @@ -219,7 +252,7 @@ def visitScalarFloat(self, ctx): def visitScalarInt(self, ctx): # type: (RelayParser.ScalarIntContext) -> expr.Constant - return expr.const(self.visit(ctx.INT())) + return expr.const(self.visit(ctx.NAT())) def visitScalarBool(self, ctx): # type: (RelayParser.ScalarBoolContext) -> expr.Constant @@ -240,7 +273,7 @@ def visitTuple(self, ctx): return expr.Tuple(tup) # Currently doesn't support mutable sequencing. - def visitSeq(self, ctx): + def visitLet(self, ctx): # type: (RelayParser.SeqContext) -> expr.Let """Desugar various sequence constructs to Relay Let nodes.""" if ctx.MUT() is not None: @@ -253,7 +286,7 @@ def visitSeq(self, ctx): else: local_var = ctx.var().ident().LOCAL_VAR() if local_var is None: - raise ParseError('Only local ids may be used in `let`s.') + raise ParseError("Only local ids may be used in `let`s.") ident = local_var.getText()[1:] type_ = self.getType_(ctx.var().type_()) @@ -278,12 +311,14 @@ def visitBinOp(self, ctx): return relay_op(arg0, arg1) + @spanify def visitVar(self, ctx): # type: (RelayParser.VarContext) -> expr.Var + """Visit a single variable.""" ident = ctx.ident().LOCAL_VAR() if ident is None: - raise ParseError('Only local ids may be used in params.') + raise ParseError("Only local ids may be used in vars.") type_ = self.getType_(ctx.type_()) @@ -293,15 +328,33 @@ def visitVarList(self, ctx): # type: (RelayParser.VarListContext) -> List[expr.Var] return self.visit_list(ctx.var()) + # TODO: support a larger class of values than just Relay exprs + def visitAttr(self, ctx): + # type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr] + return (ctx.CNAME().getText(), self.visit(ctx.expr())) + + def visitAttrList(self, ctx): + # type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr] + return dict(self.visit_list(ctx.attr())) + + def visitArgList(self, + ctx # type: RelayParser.ArgListContext + ): + # type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]] + var_list = self.visit(ctx.varList()) if ctx.varList() else None + attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None + + return (var_list, attr_list) + def mk_func(self, ctx): - # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> Function + # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function """Construct a function from either a Func or Defn.""" # Enter var scope early to put params in scope. self.enter_var_scope() # Capture type params in params. self.enter_type_param_scope() - var_list = self.visit(ctx.varList()) + var_list, attr_list = self.visit(ctx.argList()) ret_type = self.getType_(ctx.type_()) type_params = list(self.exit_type_param_scope()) @@ -311,22 +364,28 @@ def mk_func(self, ctx): body = self.visit(ctx.body()) self.exit_var_scope() - return expr.Function(var_list, body, ret_type, type_params) # type: ignore + attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None + + return expr.Function(var_list, body, ret_type, type_params, attrs) + @spanify def visitFunc(self, ctx): # type: (RelayParser.FuncContext) -> expr.Function return self.mk_func(ctx) + # TODO: how to set spans for definitions? + # @spanify def visitDefn(self, ctx): # type: (RelayParser.DefnContext) -> None ident = ctx.ident().GLOBAL_VAR() if ident is None: - raise ParseError('Only global ids may be used in `def`s.') + raise ParseError("Only global ids may be used in `def`s.") ident_name = ident.getText()[1:] ident = self.mk_global_var(ident_name) self.module[ident] = self.mk_func(ctx) + @spanify def visitCall(self, ctx): # type: (RelayParser.CallContext) -> expr.Call visited_exprs = self.visit_list(ctx.expr()) @@ -336,6 +395,7 @@ def visitCall(self, ctx): return expr.Call(func, args, None, None) + @spanify def visitIfElse(self, ctx): # type: (RelayParser.IfElseContext) -> expr.If """Construct a Relay If node. Creates a new scope for each branch.""" @@ -351,6 +411,27 @@ def visitIfElse(self, ctx): return expr.If(cond, true_branch, false_branch) + @spanify + def visitGraph(self, ctx): + # type: (RelayParser.GraphContext) -> expr.Expr + """Visit a graph variable assignment.""" + if ctx.ident().GRAPH_VAR() is None: + raise ParseError("Expected a graph var, but got `{}`".format(ctx.ident().getText())) + graph_nid = int(ctx.ident().GRAPH_VAR().getText()[1:]) + + self.enter_var_scope() + value = self.visit(ctx.expr(0)) + self.exit_var_scope() + + if graph_nid != len(self.graph_expr): + raise ParseError( + "Expected new graph variable to be `%{}`,".format(len(self.graph_expr)) + \ + "but got `%{}`".format(graph_nid)) + self.graph_expr.append(value) + + kont = self.visit(ctx.expr(1)) + return kont + # Types # pylint: disable=unused-argument @@ -428,8 +509,18 @@ def make_parser(data): token_stream = CommonTokenStream(lexer) return RelayParser(token_stream) -def fromtext(data): - # type: (str) -> Union[expr.Expr, env.Environment] +__source_name_counter__ = 0 + +def fromtext(data, source_name=None): + # type: (str, str) -> Union[expr.Expr, module.Module] """Parse a Relay program.""" + global __source_name_counter__ + + if source_name is None: + source_name = "source_file{0}".format(__source_name_counter__) + + if isinstance(source_name, str): + source_name = SourceName(source_name) + tree = make_parser(data).prog() - return ParseTreeToRelayIR().visit(tree) + return ParseTreeToRelayIR(source_name).visit(tree) diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index c50013b199ac..780d52863079 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -4,6 +4,7 @@ from .._ffi.node import NodeBase, register_node as _register_tvm_node from . import _make from . import _expr +from . import _base NodeBase = NodeBase @@ -63,6 +64,9 @@ def astext(self, show_meta_data=True, annotate=None): """ return _expr.RelayPrint(self, show_meta_data, annotate) + def set_span(self, span): + _base.set_span(self, span) + @register_relay_node class Span(RelayNode): @@ -71,6 +75,12 @@ class Span(RelayNode): def __init__(self, source, lineno, col_offset): self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset) +@register_relay_node +class SourceName(RelayNode): + """A identifier for a source location""" + + def __init__(self, name): + self.__init_handle_by_constructor__(_make.SourceName, name) @register_relay_node class Id(NodeBase): diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index cf6f9a7caa2b..0a2206265502 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -1,5 +1,7 @@ grammar Relay; +SEMVER: 'v0.0.1' ; + // Lexing // comments WS : [ \t\n\r]+ -> skip ; @@ -20,7 +22,8 @@ NE: '!=' ; opIdent: CNAME ; GLOBAL_VAR: '@' CNAME ; -LOCAL_VAR: '%' CNAME ; +LOCAL_VAR: '%' CNAME; +GRAPH_VAR: '%' NAT; MUT: 'mut' ; @@ -31,13 +34,13 @@ BOOL_LIT // non-negative floats FLOAT - : INT '.' INT EXP? // 1.35, 1.35E-9, 0.3, 4.5 - | INT EXP // 1e10 3e4 + : NAT '.' NAT EXP? // 1.35, 1.35E-9, 0.3, 4.5 + | NAT EXP // 1e10 3e4 ; // non-negative ints -INT: DIGIT+ ; -fragment EXP: [eE] [+\-]? INT ; // \- since - means "range" inside [...] +NAT: DIGIT+ ; +fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...] CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ; fragment LETTER: [a-zA-Z] ; @@ -46,7 +49,7 @@ fragment DIGIT: [0-9] ; // Parsing // A Relay program is a list of global definitions or an expression. -prog: (defn* | expr) EOF ; +prog: SEMVER (defn* | expr) EOF ; // option: 'set' ident BOOL_LIT ; @@ -73,10 +76,11 @@ expr | 'if' '(' expr ')' body 'else' body # ifElse // sequencing - | 'let' MUT? var '=' expr ';' expr # seq - | 'let' MUT? var '=' '{' expr '}' ';' expr # seq + | 'let' MUT? var '=' expr ';' expr # let + | 'let' MUT? var '=' '{' expr '}' ';' expr # let // sugar for let %_ = expr; expr - | expr ';' expr # seq + | expr ';' expr # let + | ident '=' expr ';' expr # graph // mutable update // | ident '=' expr # writeRef @@ -84,16 +88,25 @@ expr | ident # identExpr | scalar # scalarExpr - // | expr '.' INT # project - // | 'debug' # debug + // | expr '.' NAT # project + // | 'debug' # debug ; -func: 'fn' varList ('->' type_)? body ; -defn: 'def' ident varList ('->' type_)? body ; +func: 'fn' '(' argList ')' ('->' type_)? body ; +defn: 'def' ident '(' argList ')' ('->' type_)? body ; + +argList + : varList + | attrList + | varList ',' attrList + ; -varList: '(' (var (',' var)*)? ')' ; +varList: (var (',' var)*)? ; var: ident (':' type_)? ; +attrList: (attr (',' attr)*)? ; +attr: CNAME '=' expr ; + // TODO(@jmp): for improved type annotations // returnAnno: (ident ':')? type_ ; @@ -110,7 +123,7 @@ type_ // | identType '[' (type_ (',' type_)*)? ']' # callType | 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType | '_' # incompleteType - | INT # intType + | NAT # intType ; shapeSeq @@ -123,20 +136,20 @@ shape : '(' shape ')' # parensShape // | type_ op=('*'|'/') type_ # binOpType // | type_ op=('+'|'-') type_ # binOpType - | INT # intShape + | NAT # intShape ; identType: CNAME ; -// Int8, Int16, Int32, Int64 -// UInt8, UInt16, UInt32, UInt64 -// Float16, Float32, Float64 -// Bool +// int8, int16, int32, int64 +// uint8, uint16, uint32, uint64 +// float16, float32, float64 +// bool body: '{' expr '}' ; scalar : FLOAT # scalarFloat - | INT # scalarInt + | NAT # scalarInt | BOOL_LIT # scalarBool ; @@ -144,4 +157,5 @@ ident : opIdent | GLOBAL_VAR | LOCAL_VAR + | GRAPH_VAR ; diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 51200343f147..ba0b1aac063e 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -1,8 +1,13 @@ """A parser for Relay's text format.""" from __future__ import absolute_import +from .. import register_func def enabled(): - """Is the parser enabled/Can we import the parser?""" + """Checks whether the parser is enabled, this allows users to + optionally support building the parser. + + We use this check before importing the parser. + """ try: # pylint: disable=unused-variable from tvm.relay import _parser @@ -11,7 +16,8 @@ def enabled(): except Exception: return False -def fromtext(data): +@register_func("relay.fromtext") +def fromtext(data, source_name=None): """Parse a Relay program.""" from tvm.relay import _parser - return _parser.fromtext(data) + return _parser.fromtext(data, source_name) diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 06593b6420f5..8df54883616a 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -32,6 +32,11 @@ SourceName SourceName::Get(const std::string& name) { return SourceName(GetSourceNameNode(name)); } +TVM_REGISTER_API("relay._make.SourceName") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = SourceName::Get(args[0]); + }); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const SourceNameNode* node, tvm::IRPrinter* p) { p->stream << "SourceName(" << node->name << ", " << node << ")"; @@ -66,5 +71,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(IdNode); +TVM_REGISTER_API("relay._base.set_span") +.set_body([](TVMArgs args, TVMRetValue* ret) { + NodeRef node_ref = args[0]; + auto rn = node_ref.as_derived(); + CHECK(rn); + Span sp = args[1]; + rn->span = sp; +}); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index d32750a4aafa..08d4c430101b 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -8,11 +8,12 @@ from typing import Union from functools import wraps if enabled(): - from tvm.relay._parser import ParseError - raises_parse_error = raises(ParseError) + raises_parse_error = raises(tvm._ffi.base.TVMError) else: raises_parse_error = lambda x: x +SEMVER = "v0.0.1" + BINARY_OPS = { "*": relay.multiply, "/": relay.divide, @@ -48,6 +49,10 @@ "float16x4", } +def parses_as(code, expr): + # type: (str, relay.Expr) -> bool + return alpha_equal(relay.fromtext(SEMVER + "\n" + code), expr) + def get_scalar(x): # type: (relay.Constant) -> (Union[float, int, bool]) return x.data.asnumpy().item() @@ -74,80 +79,80 @@ def wrapper(): @if_parser_enabled def test_comments(): - assert alpha_equal( - relay.fromtext(""" - // This is a line comment! - () - """), + assert parses_as( + """ + // This is a line comment! + () + """, UNIT ) - assert alpha_equal( - relay.fromtext(""" - /* This is a block comment! - This is still a block comment! - */ - () - """), + assert parses_as( + """ + /* This is a block comment! + This is still a block comment! + */ + () + """, UNIT ) @if_parser_enabled def test_int_literal(): - assert isinstance(relay.fromtext("1"), relay.Constant) - assert isinstance(relay.fromtext("1").data, tvm.ndarray.NDArray) + assert isinstance(relay.fromtext(SEMVER+"1"), relay.Constant) + assert isinstance(relay.fromtext(SEMVER+"1").data, tvm.ndarray.NDArray) - assert get_scalar(relay.fromtext("1")) == 1 - assert get_scalar(relay.fromtext("10")) == 10 - assert get_scalar(relay.fromtext("0")) == 0 - assert get_scalar(relay.fromtext("-100")) == -100 - assert get_scalar(relay.fromtext("-05")) == -5 + assert get_scalar(relay.fromtext(SEMVER+"1")) == 1 + assert get_scalar(relay.fromtext(SEMVER+"10")) == 10 + assert get_scalar(relay.fromtext(SEMVER+"0")) == 0 + assert get_scalar(relay.fromtext(SEMVER+"-100")) == -100 + assert get_scalar(relay.fromtext(SEMVER+"-05")) == -5 @if_parser_enabled def test_float_literal(): - assert get_scalar(relay.fromtext("1.0")) == 1.0 - assert isclose(get_scalar(relay.fromtext("1.56667")), 1.56667) - assert get_scalar(relay.fromtext("0.0")) == 0.0 - assert get_scalar(relay.fromtext("-10.0")) == -10.0 + assert get_scalar(relay.fromtext(SEMVER+"1.0")) == 1.0 + assert isclose(get_scalar(relay.fromtext(SEMVER+"1.56667")), 1.56667) + assert get_scalar(relay.fromtext(SEMVER+"0.0")) == 0.0 + assert get_scalar(relay.fromtext(SEMVER+"-10.0")) == -10.0 # scientific notation - assert isclose(get_scalar(relay.fromtext("1e-1")), 1e-1) - assert get_scalar(relay.fromtext("1e+1")) == 1e+1 - assert isclose(get_scalar(relay.fromtext("1E-1")), 1E-1) - assert get_scalar(relay.fromtext("1E+1")) == 1E+1 - assert isclose(get_scalar(relay.fromtext("1.0e-1")), 1.0e-1) - assert get_scalar(relay.fromtext("1.0e+1")) == 1.0e+1 - assert isclose(get_scalar(relay.fromtext("1.0E-1")), 1.0E-1) - assert get_scalar(relay.fromtext("1.0E+1")) == 1.0E+1 + assert isclose(get_scalar(relay.fromtext(SEMVER+"1e-1")), 1e-1) + assert get_scalar(relay.fromtext(SEMVER+"1e+1")) == 1e+1 + assert isclose(get_scalar(relay.fromtext(SEMVER+"1E-1")), 1E-1) + assert get_scalar(relay.fromtext(SEMVER+"1E+1")) == 1E+1 + assert isclose(get_scalar(relay.fromtext(SEMVER+"1.0e-1")), 1.0e-1) + assert get_scalar(relay.fromtext(SEMVER+"1.0e+1")) == 1.0e+1 + assert isclose(get_scalar(relay.fromtext(SEMVER+"1.0E-1")), 1.0E-1) + assert get_scalar(relay.fromtext(SEMVER+"1.0E+1")) == 1.0E+1 @if_parser_enabled def test_bool_literal(): - assert get_scalar(relay.fromtext("True")) == True - assert get_scalar(relay.fromtext("False")) == False + assert get_scalar(relay.fromtext(SEMVER+"True")) == True + assert get_scalar(relay.fromtext(SEMVER+"False")) == False @if_parser_enabled def test_negative(): - assert isinstance(relay.fromtext("let %x = 1; -%x").body, relay.Call) - assert get_scalar(relay.fromtext("--10")) == 10 - assert get_scalar(relay.fromtext("---10")) == -10 + assert isinstance(relay.fromtext(SEMVER+"let %x = 1; -%x").body, relay.Call) + assert get_scalar(relay.fromtext(SEMVER+"--10")) == 10 + assert get_scalar(relay.fromtext(SEMVER+"---10")) == -10 @if_parser_enabled def test_bin_op(): for bin_op in BINARY_OPS.keys(): - assert alpha_equal( - relay.fromtext("1 {} 1".format(bin_op)), + assert parses_as( + "1 {} 1".format(bin_op), BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1)) ) @if_parser_enabled def test_parens(): - assert alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("(1 * 1) + 1")) - assert not alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("1 * (1 + 1)")) + assert alpha_equal(relay.fromtext(SEMVER+"1 * 1 + 1"), relay.fromtext(SEMVER+"(1 * 1) + 1")) + assert not alpha_equal(relay.fromtext(SEMVER+"1 * 1 + 1"), relay.fromtext(SEMVER+"1 * (1 + 1)")) @if_parser_enabled def test_op_assoc(): - assert alpha_equal(relay.fromtext("1 * 1 + 1 < 1 == 1"), relay.fromtext("(((1 * 1) + 1) < 1) == 1")) - assert alpha_equal(relay.fromtext("1 == 1 < 1 + 1 * 1"), relay.fromtext("1 == (1 < (1 + (1 * 1)))")) + assert alpha_equal(relay.fromtext(SEMVER+"1 * 1 + 1 < 1 == 1"), relay.fromtext(SEMVER+"(((1 * 1) + 1) < 1) == 1")) + assert alpha_equal(relay.fromtext(SEMVER+"1 == 1 < 1 + 1 * 1"), relay.fromtext(SEMVER+"1 == (1 < (1 + (1 * 1)))")) @nottest @if_parser_enabled @@ -159,24 +164,24 @@ def test_vars(): # assert temp_var.name == "1" # var - var = relay.fromtext("let %foo = (); %foo") + var = relay.fromtext(SEMVER+"let %foo = (); %foo") assert isinstance(var.body, relay.Var) assert var.body.name_hint == "foo" # global var - global_var = relay.fromtext("@foo") + global_var = relay.fromtext(SEMVER+"@foo") assert isinstance(global_var, relay.GlobalVar) assert global_var.name_hint == "foo" # operator id - op = relay.fromtext("foo") + op = relay.fromtext(SEMVER+"foo") assert isinstance(op, relay.Op) assert op.name == "foo" @if_parser_enabled def test_let(): - assert alpha_equal( - relay.fromtext("let %x = 1; ()"), + assert parses_as( + "let %x = 1; ()", relay.Let( X, relay.const(1), @@ -184,18 +189,35 @@ def test_let(): ) ) + assert parses_as( + """ + let %x = 1; + let %y = 2; + () + """, + relay.Let( + X, + relay.const(1), + relay.Let( + Y, + relay.const(2), + UNIT + ) + ) + ) + @if_parser_enabled def test_seq(): - assert alpha_equal( - relay.fromtext("(); ()"), + assert parses_as( + "(); ()", relay.Let( _, UNIT, UNIT) ) - assert alpha_equal( - relay.fromtext("let %_ = { 1 }; ()"), + assert parses_as( + "let %_ = { 1 }; ()", relay.Let( X, relay.const(1), @@ -203,31 +225,48 @@ def test_seq(): ) ) +@if_parser_enabled +def test_graph(): + assert parses_as( + "%0 = (); %1 = 1; (%0, %0, %1)", + relay.Tuple([UNIT, UNIT, relay.const(1)]) + ) + + assert not parses_as( + "%0 = (); %1 = 1; (%0, %0, %1)", + relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)]) + ) + +@raises_parse_error +@if_parser_enabled +def test_graph_wrong_order(): + relay.fromtext(SEMVER+"%1 = (); %1") + @raises_parse_error @if_parser_enabled def test_let_global_var(): - relay.fromtext("let @x = 1; ()") + relay.fromtext(SEMVER+"let @x = 1; ()") @raises_parse_error @if_parser_enabled def test_let_op(): - relay.fromtext("let x = 1; ()") + relay.fromtext(SEMVER+"let x = 1; ()") @if_parser_enabled def test_tuple(): - assert alpha_equal(relay.fromtext("()"), relay.Tuple([])) + assert parses_as("()", relay.Tuple([])) - assert alpha_equal(relay.fromtext("(0,)"), relay.Tuple([relay.const(0)])) + assert parses_as("(0,)", relay.Tuple([relay.const(0)])) - assert alpha_equal(relay.fromtext("(0, 1)"), relay.Tuple([relay.const(0), relay.const(1)])) + assert parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)])) - assert alpha_equal(relay.fromtext("(0, 1, 2)"), relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) + assert parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) @if_parser_enabled def test_func(): # 0 args - assert alpha_equal( - relay.fromtext("fn () { 0 }"), + assert parses_as( + "fn () { 0 }", relay.Function( [], relay.const(0), @@ -237,8 +276,8 @@ def test_func(): ) # 1 arg - assert alpha_equal( - relay.fromtext("fn (%x) { %x }"), + assert parses_as( + "fn (%x) { %x }", relay.Function( [X], X, @@ -248,8 +287,8 @@ def test_func(): ) # 2 args - assert alpha_equal( - relay.fromtext("fn (%x, %y) { %x + %y }"), + assert parses_as( + "fn (%x, %y) { %x + %y }", relay.Function( [X, Y], relay.add(X, Y), @@ -259,8 +298,8 @@ def test_func(): ) # annotations - assert alpha_equal( - relay.fromtext("fn (%x: int32) -> int32 { %x }"), + assert parses_as( + "fn (%x: int32) -> int32 { %x }", relay.Function( [X_ANNO], X_ANNO, @@ -269,11 +308,17 @@ def test_func(): ) ) + # attributes + assert parses_as( + "fn (n=5) { () }", + relay.Function([], UNIT, None, None, tvm.make.node("DictAttrs", n=relay.const(5))) + ) + # TODO(@jmp): Crashes if %x isn't annnotated. -# @nottest @if_parser_enabled def test_defn(): id_defn = relay.fromtext( + SEMVER+ """ def @id(%x: int32) -> int32 { %x @@ -284,6 +329,7 @@ def @id(%x: int32) -> int32 { @if_parser_enabled def test_recursive_call(): id_defn = relay.fromtext( + SEMVER+ """ def @id(%x: int32) -> int32 { @id(%x) @@ -293,16 +339,14 @@ def @id(%x: int32) -> int32 { @if_parser_enabled def test_ifelse(): - assert alpha_equal( - relay.fromtext( + assert parses_as( """ if (True) { 0 } else { 1 } - """ - ), + """, relay.If( relay.const(True), relay.const(0), @@ -314,6 +358,7 @@ def test_ifelse(): @if_parser_enabled def test_ifelse_scope(): relay.fromtext( + SEMVER+ """ if (True) { let %x = (); @@ -328,13 +373,11 @@ def test_ifelse_scope(): def test_call(): # select right function to call: simple ident case id_func = relay.Var("id") - assert alpha_equal( - relay.fromtext( + assert parses_as( """ let %id = fn (%x) { %x }; 10 * %id(10) - """ - ), + """, relay.Let( id_func, relay.Function([X], X, None, []), @@ -344,13 +387,11 @@ def test_call(): # 0 args constant = relay.Var("constant") - assert alpha_equal( - relay.fromtext( + assert parses_as( """ let %constant = fn () { 0 }; %constant() - """ - ), + """, relay.Let( constant, relay.Function([], relay.const(0), None, []), @@ -360,13 +401,11 @@ def test_call(): # 1 arg id_var = relay.Var("id") - assert alpha_equal( - relay.fromtext( - """ - let %id = fn (%x) { %x }; - %id(1) - """ - ), + assert parses_as( + """ + let %id = fn (%x) { %x }; + %id(1) + """, relay.Let( id_var, relay.Function([X], X, None, []), @@ -376,13 +415,11 @@ def test_call(): # 2 args multiply = relay.Var("multiply") - assert alpha_equal( - relay.fromtext( + assert parses_as( """ let %multiply = fn (%x, %y) { %x * %y }; %multiply(0, 0) - """ - ), + """, relay.Let( multiply, relay.Function( @@ -396,12 +433,10 @@ def test_call(): ) # anonymous function - assert alpha_equal( - relay.fromtext( + assert parses_as( """ (fn (%x) { %x })(0) - """ - ), + """, relay.Call( relay.Function( [X], @@ -415,45 +450,44 @@ def test_call(): ) ) + # TODO(@jmp): re-enable after sequence parsing improvements # curried function - curried_mult = relay.Var("curried_mult") - alpha_equal( - relay.fromtext( - """ - let %curried_mult = - fn (%x) { - fn (%y) { - %x * %y - } - }; - %curried_mult(0); - %curried_mult(0)(0) - """ - ), - relay.Let( - curried_mult, - relay.Function( - [X], - relay.Function( - [Y], - relay.multiply(X, Y), - None, - [] - ), - None, - [] - ), - relay.Let( - _, - relay.Call(curried_mult, [relay.const(0)], None, None), - relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None) - ) - ) - ) + # curried_mult = relay.Var("curried_mult") + # assert parses_as( + # """ + # let %curried_mult = + # fn (%x) { + # fn (%y) { + # %x * %y + # } + # }; + # %curried_mult(0); + # %curried_mult(0)(0) + # """, + # relay.Let( + # curried_mult, + # relay.Function( + # [X], + # relay.Function( + # [Y], + # relay.multiply(X, Y), + # None, + # [] + # ), + # None, + # [] + # ), + # relay.Let( + # _, + # relay.Call(curried_mult, [relay.const(0)], None, None), + # relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None) + # ) + # ) + # ) # op - alpha_equal( - relay.fromtext("abs(1)"), + assert parses_as( + "abs(1)", relay.Call(relay.op.get("abs"), [relay.const(1)], None, None) ) @@ -461,8 +495,8 @@ def test_call(): @if_parser_enabled def test_incomplete_type(): - assert alpha_equal( - relay.fromtext("let %_ : _ = (); ()"), + assert parses_as( + "let %_ : _ = (); ()", relay.Let( _, UNIT, @@ -473,7 +507,7 @@ def test_incomplete_type(): @if_parser_enabled def test_builtin_types(): for builtin_type in TYPES: - relay.fromtext("let %_ : {} = (); ()".format(builtin_type)) + relay.fromtext(SEMVER+"let %_ : {} = (); ()".format(builtin_type)) @nottest @if_parser_enabled @@ -482,8 +516,8 @@ def test_call_type(): @if_parser_enabled def test_tensor_type(): - assert alpha_equal( - relay.fromtext("let %_ : Tensor[(), float32] = (); ()"), + assert parses_as( + "let %_ : Tensor[(), float32] = (); ()", relay.Let( relay.Var("_", relay.TensorType((), "float32")), UNIT, @@ -491,8 +525,8 @@ def test_tensor_type(): ) ) - assert alpha_equal( - relay.fromtext("let %_ : Tensor[(1,), float32] = (); ()"), + assert parses_as( + "let %_ : Tensor[(1,), float32] = (); ()", relay.Let( relay.Var("_", relay.TensorType((1,), "float32")), UNIT, @@ -500,8 +534,8 @@ def test_tensor_type(): ) ) - assert alpha_equal( - relay.fromtext("let %_ : Tensor[(1, 1), float32] = (); ()"), + assert parses_as( + "let %_ : Tensor[(1, 1), float32] = (); ()", relay.Let( relay.Var("_", relay.TensorType((1, 1), "float32")), UNIT, @@ -511,12 +545,10 @@ def test_tensor_type(): @if_parser_enabled def test_function_type(): - assert alpha_equal( - relay.fromtext( - """ - let %_: fn () -> int32 = fn () -> int32 { 0 }; () - """ - ), + assert parses_as( + """ + let %_: fn () -> int32 = fn () -> int32 { 0 }; () + """, relay.Let( relay.Var("_", relay.FuncType([], int32, [], [])), relay.Function([], relay.const(0), int32, []), @@ -524,12 +556,10 @@ def test_function_type(): ) ) - assert alpha_equal( - relay.fromtext( - """ - let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; () - """ - ), + assert parses_as( + """ + let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; () + """, relay.Let( relay.Var("_", relay.FuncType([int32], int32, [], [])), relay.Function([relay.Var("x", int32)], relay.const(0), int32, []), @@ -537,12 +567,10 @@ def test_function_type(): ) ) - assert alpha_equal( - relay.fromtext( - """ - let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; () - """ - ), + assert parses_as( + """ + let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; () + """, relay.Let( relay.Var("_", relay.FuncType([int32, int32], int32, [], [])), relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []), @@ -552,11 +580,10 @@ def test_function_type(): @if_parser_enabled def test_tuple_type(): - assert alpha_equal( - relay.fromtext( + assert parses_as( """ let %_: () = (); () - """), + """, relay.Let( relay.Var("_", relay.TupleType([])), UNIT, @@ -564,11 +591,10 @@ def test_tuple_type(): ) ) - assert alpha_equal( - relay.fromtext( + assert parses_as( """ let %_: (int32,) = (0,); () - """), + """, relay.Let( relay.Var("_", relay.TupleType([int32])), relay.Tuple([relay.const(0)]), @@ -576,11 +602,10 @@ def test_tuple_type(): ) ) - assert alpha_equal( - relay.fromtext( + assert parses_as( """ let %_: (int32, int32) = (0, 1); () - """), + """, relay.Let( relay.Var("_", relay.TupleType([int32, int32])), relay.Tuple([relay.const(0), relay.const(1)]),