Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] parser/pretty printer roundtripping #3536

Merged
merged 3 commits into from
Jul 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 168 additions & 52 deletions python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name, unused-import
# pylint: disable=invalid-name, unused-argument
"""A parser for Relay's text format."""
from __future__ import absolute_import

import sys
from ast import literal_eval

from collections import deque
from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any, Dict

import tvm

Expand All @@ -32,6 +32,23 @@
from . import ty
from . import op

PYTHON_VERSION = sys.version_info.major
try:
from .grammar.py3.RelayVisitor import RelayVisitor
from .grammar.py3.RelayParser import RelayParser
from .grammar.py3.RelayLexer import RelayLexer
except ImportError:
raise Exeption("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.")

try:
from antlr4 import InputStream, CommonTokenStream
from antlr4.error.ErrorListener import ErrorListener
except ImportError:
raise Exception("Couldn't find ANTLR runtime." +
"Try running `pip{version} install antlr4-python{version}-runtime`."
.format(version=PYTHON_VERSION))

sys.setrecursionlimit(10000)

class ParseError(Exception):
"""Exception type for parse errors."""
Expand All @@ -41,21 +58,50 @@ def __init__(self, message):
super(ParseError, self).__init__()
self.message = message

PYTHON_VERSION = sys.version_info.major
try:
from .grammar.py3.RelayVisitor import RelayVisitor
from .grammar.py3.RelayParser import RelayParser
from .grammar.py3.RelayLexer import RelayLexer
except ImportError:
raise ParseError("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.")
def __repr__(self):
return "ParseError({})".format(self.message)

try:
from antlr4 import ParserRuleContext, InputStream, CommonTokenStream
from antlr4.tree.Tree import TerminalNode
except ImportError:
raise ParseError("Couldn't find ANTLR runtime." +
"Try running `pip{version} install antlr4-python{version}-runtime`."
.format(version=PYTHON_VERSION))
def __str__(self):
return repr(self)

class OpWrapper:
"""Overload the __call__ for op."""
pass

class ExprOp(OpWrapper):
"""Call an expr. The default, but does not handle attrs well."""
def __init__(self, operator):
self.operator = operator

def __call__(self, args, attrs, type_args):
try:
return expr.Call(self.operator, args, attrs, type_args)
except Exception:
raise Exception(str(self.operator) + " " + str(attrs))

class FuncOp(OpWrapper):
"""Convert the attrs, call the python function with the attrs passed in as keyword arguments.
Tvm should provide this in the future, as this is pretty similar to what op.get is providing.
"""
def __init__(self, operator):
self.operator = operator

def convert(self, v):
if isinstance(v, tuple):
return tuple([self.convert(x) for x in v])
if isinstance(v, expr.Constant):
return v.data.asnumpy().item()
if isinstance(v, str):
return v
raise Exception(v)

def __call__(self, args, attrs, type_args):
if attrs is None:
attrs = {}
x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
if isinstance(x, expr.TupleWrapper):
x = x.astuple()
return x

BINARY_OPS = {
RelayParser.MUL: op.multiply,
Expand All @@ -70,16 +116,34 @@ def __init__(self, message):
RelayParser.NE: op.not_equal,
}

FUNC_OPS = {
"nn.conv2d": op.nn.conv2d,
"nn.batch_norm": op.nn.batch_norm,
"nn.dense": op.nn.dense,
"nn.bias_add": op.nn.bias_add,
"nn.max_pool2d": op.nn.max_pool2d,
"nn.global_max_pool2d": op.nn.global_max_pool2d,
"nn.avg_pool2d": op.nn.avg_pool2d,
"nn.global_avg_pool2d": op.nn.global_avg_pool2d,
"nn.softmax": op.nn.softmax,
"reshape": op.reshape,
"nn.conv2d_transpose": op.nn.conv2d_transpose,
"concatenate": op.concatenate,
"nn.dropout": op.nn.dropout_raw,
"zeros": op.zeros,
"split": op.split,
}

TYPE_PREFIXES = [
"int",
"uint",
"float",
"bool",
]

T = TypeVar("T")
Scope = Deque[Tuple[str, T]]
Scopes = Deque[Scope[T]]
T = ty.TypeVar("T")
# Scope = Deque[Tuple[str, T]]
# Scopes = Deque[Scope[T]]

def lookup(scopes, name):
# type: (Scopes[T], str) -> Optional[T]
Expand Down Expand Up @@ -108,6 +172,8 @@ def _wrapper(*args, **kwargs):
ast = f(*args, **kwargs)
line, col = ctx.getSourceInterval()
sp = Span(sn, line, col)
if isinstance(ast, tvm.relay.expr.TupleWrapper):
ast = ast.astuple()
ast.set_span(sp)
return ast
return _wrapper
Expand Down Expand Up @@ -179,6 +245,9 @@ def mk_typ(self, name, kind):
self.type_param_scopes[0].appendleft((name, typ))
return typ

def visitProjection(self, ctx):
return expr.TupleGetItem(self.visit(ctx.expr()), self.visit(ctx.NAT()))

def visitTerminal(self, node):
# type: (TerminalNode) -> Union[expr.Expr, int, float]
"""Visit lexer tokens that aren't ignored or visited by other functions."""
Expand Down Expand Up @@ -213,12 +282,15 @@ def visitTerminal(self, node):
if node_text == "False":
return False
raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text))
if node_type == RelayLexer.QUOTED_STRING:
return literal_eval(node_text)

raise ParseError("todo: {}".format(node_text))
raise ParseError("todo: `{}`".format(node_text))

def visit_list(self, ctx_list):
# type: (List[ParserRuleContext]) -> List[Any]
""""Visit a list of contexts."""
assert isinstance(ctx_list, list)

return [self.visit(ctx) for ctx in ctx_list]

Expand All @@ -232,6 +304,11 @@ def getType_(self, ctx):
return self.visit(ctx)

def visitProg(self, ctx):
self.meta = None
if ctx.METADATA():
header, data = str(ctx.METADATA()).split('\n', 1)
assert header == "METADATA:"
self.meta = tvm.load_json(data)
# type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module]
if ctx.defn():
self.visit_list(ctx.defn())
Expand All @@ -245,11 +322,14 @@ def visitProg(self, ctx):
# Exprs
def visitOpIdent(self, ctx):
# type: (RelayParser.OpIdentContext) -> op.Op
return op.get(ctx.CNAME().getText())
op_name = ctx.CNAME().getText()
if op_name in FUNC_OPS:
return FuncOp(FUNC_OPS[op_name])
return ExprOp(op.get(op_name))

# pass through
def visitParens(self, ctx):
# type: (RelayParser.ParensContext) -> expr.Expr
def visitParen(self, ctx):
# type: (RelayParser.ParenContext) -> expr.Expr
return self.visit(ctx.expr())

# pass through
Expand Down Expand Up @@ -283,25 +363,17 @@ def visitTuple(self, ctx):
tup = self.visit_list(ctx.expr())
return expr.Tuple(tup)

# Currently doesn't support mutable sequencing.
def visitLet(self, ctx):
# type: (RelayParser.SeqContext) -> expr.Let
"""Desugar various sequence constructs to Relay Let nodes."""
if ctx.MUT() is not None:
raise ParseError("Mutation is currently unsupported.")

if ctx.var() is None or ctx.var().ident() is None:
if ctx.var() is None:
# anonymous identity
ident = "_"
type_ = None
var = self.mk_var(ident, type_)
else:
local_var = ctx.var().ident().LOCAL_VAR()
if local_var is None:
raise ParseError("Only local ids may be used in `let`s.")
ident = local_var.getText()[1:]
type_ = self.getType_(ctx.var().type_())

var = self.mk_var(ident, type_)
var = self.visitVar(ctx.var())

self.enter_var_scope()
value = self.visit(ctx.expr(0))
Expand All @@ -326,7 +398,7 @@ def visitBinOp(self, ctx):
def visitVar(self, ctx):
# type: (RelayParser.VarContext) -> expr.Var
"""Visit a single variable."""
ident = ctx.ident().LOCAL_VAR()
ident = ctx.LOCAL_VAR()

if ident is None:
raise ParseError("Only local ids may be used in vars.")
Expand All @@ -344,19 +416,29 @@ def visitAttr(self, ctx):
# type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr]
return (ctx.CNAME().getText(), self.visit(ctx.expr()))

def visitAttrList(self, ctx):
def visitArgNoAttr(self, ctx):
return (self.visit_list(ctx.varList().var()), None)

def visitAttrSeq(self, ctx):
# type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr]
return dict(self.visit_list(ctx.attr()))

def visitArgWithAttr(self, ctx):
return (self.visit_list(ctx.var()), self.visitAttrSeq(ctx.attrSeq()))

def visitArgList(self,
ctx # type: RelayParser.ArgListContext
):
# type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]
var_list = self.visit(ctx.varList()) if ctx.varList() else None
attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None

return (var_list, attr_list)

def visitMeta(self, ctx):
type_key = str(ctx.CNAME())
index = int(self.visit(ctx.NAT()))
return self.meta[type_key][index]

def mk_func(self, ctx):
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function
"""Construct a function from either a Func or Defn."""
Expand All @@ -365,7 +447,7 @@ def mk_func(self, ctx):
self.enter_var_scope()
# Capture type params in params.
self.enter_type_param_scope()
type_params = ctx.typeParamSeq()
type_params = ctx.typeParamList()

if type_params is not None:
type_params = type_params.ident()
Expand Down Expand Up @@ -405,18 +487,25 @@ def visitDefn(self, ctx):
raise ParseError("Only global ids may be used in `def`s.")
ident_name = ident.getText()[1:]
ident = self.mk_global_var(ident_name)

self.module[ident] = self.mk_func(ctx)

def visitCallNoAttr(self, ctx):
return (self.visit_list(ctx.exprList().expr()), None)

def visitCallWithAttr(self, ctx):
return (self.visit_list(ctx.expr()), self.visit(ctx.attrSeq()))

def call(self, func, args, attrs, type_args):
if isinstance(func, OpWrapper):
return func(args, attrs, type_args)
return expr.Call(func, args, attrs, type_args)

@spanify
def visitCall(self, ctx):
# type: (RelayParser.CallContext) -> expr.Call
visited_exprs = self.visit_list(ctx.expr())

func = visited_exprs[0]
args = visited_exprs[1:]

return expr.Call(func, args, None, None)
func = self.visit(ctx.expr())
args, attrs = self.visit(ctx.callList())
return self.call(func, args, attrs, [])

@spanify
def visitIfElse(self, ctx):
Expand All @@ -438,9 +527,7 @@ def visitIfElse(self, ctx):
def visitGraph(self, ctx):
# type: (RelayParser.GraphContext) -> expr.Expr
"""Visit a graph variable assignment."""
if ctx.ident().GRAPH_VAR() is None:
raise ParseError("Expected a graph var, but got `{}`".format(ctx.ident().getText()))
graph_nid = int(ctx.ident().GRAPH_VAR().getText()[1:])
graph_nid = int(ctx.GRAPH_VAR().getText()[1:])

self.enter_var_scope()
value = self.visit(ctx.expr(0))
Expand Down Expand Up @@ -500,15 +587,18 @@ def visitParensShape(self, ctx):
# type: (RelayParser.ParensShapeContext) -> int
return self.visit(ctx.shape())

def visitShapeSeq(self, ctx):
# type: (RelayParser.ShapeSeqContext) -> List[int]
def visitShapeList(self, ctx):
# type: (RelayParser.ShapeListContext) -> List[int]
return self.visit_list(ctx.shape())

def visitTensor(self, ctx):
return tuple(self.visit_list(ctx.expr()))

def visitTensorType(self, ctx):
# type: (RelayParser.TensorTypeContext) -> ty.TensorType
"""Create a simple tensor type. No generics."""

shape = self.visit(ctx.shapeSeq())
shape = self.visit(ctx.shapeList())
dtype = self.visit(ctx.type_())

if not isinstance(dtype, ty.TensorType):
Expand Down Expand Up @@ -536,11 +626,37 @@ def make_parser(data):
"""Construct a RelayParser a given data stream."""
input_stream = InputStream(data)
lexer = RelayLexer(input_stream)
lexer.addErrorListener(StrictErrorListener(data))
token_stream = CommonTokenStream(lexer)
return RelayParser(token_stream)
p = RelayParser(token_stream)
p.addErrorListener(StrictErrorListener(data))
return p

__source_name_counter__ = 0

class StrictErrorListener(ErrorListener):
"""This ErrorListener fail eagerly on all error, and report the program."""
def __init__(self, text):
self.text = text

def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
raise Exception("Syntax Error in:\n" + self.text)

def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs):
raise Exception("Ambiguity Error in:\n" + self.text)

def reportAttemptingFullContext(self,
recognizer,
dfa,
startIndex,
stopIndex,
conflictingAlts,
configs):
raise Exception("Attempting Full Context in:\n" + self.text)

def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs):
raise Exception("Context Sensitivity in:\n" + self.text)

def fromtext(data, source_name=None):
# type: (str, str) -> Union[expr.Expr, module.Module]
"""Parse a Relay program."""
Expand Down
Loading