From 8f81ea84a234bf06f5be4c01013da2fb71ab70cf Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Tue, 18 Jun 2019 19:36:51 -0700 Subject: [PATCH 1/3] commit --- {src => include/tvm}/common/base64.h | 0 include/tvm/json.h | 317 +++ python/tvm/relay/_parser.py | 214 ++- python/tvm/relay/analysis.py | 31 + python/tvm/relay/grammar/Relay.g4 | 125 +- python/tvm/relay/grammar/py3/RelayLexer.py | 315 +-- python/tvm/relay/grammar/py3/RelayParser.py | 1805 +++++++++++------- python/tvm/relay/grammar/py3/RelayVisitor.py | 89 +- python/tvm/relay/op/nn/nn.py | 46 +- python/tvm/relay/parser.py | 5 +- python/tvm/relay/testing/densenet.py | 2 +- python/tvm/relay/ty.py | 2 +- src/lang/reflection.cc | 282 +-- src/relay/ir/alpha_equal.cc | 105 +- src/relay/ir/doc.cc | 2 +- src/relay/ir/doc.h | 14 +- src/relay/ir/pretty_printer.cc | 189 +- tests/python/relay/test_ir_parser.py | 23 +- tests/python/relay/test_ir_text_printer.py | 80 +- 19 files changed, 2206 insertions(+), 1440 deletions(-) rename {src => include/tvm}/common/base64.h (100%) create mode 100644 include/tvm/json.h diff --git a/src/common/base64.h b/include/tvm/common/base64.h similarity index 100% rename from src/common/base64.h rename to include/tvm/common/base64.h diff --git a/include/tvm/json.h b/include/tvm/json.h new file mode 100644 index 000000000000..8f681f24abf9 --- /dev/null +++ b/include/tvm/json.h @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file json.h + * \brief A representation of JSON + */ + +#ifndef TVM_JSON_H_ +#define TVM_JSON_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { + +// use map so attributes are ordered. +using AttrMap = std::map; + +using runtime::Object; +using runtime::ObjectCell; + +inline std::string Type2String(const Type& t) { + return runtime::TVMType2String(Type2TVMType(t)); +} + +// indexer to index all the ndoes +class NodeIndexer : public AttrVisitor { + public: + std::unordered_map node_index{{nullptr, 0}}; + std::vector node_list{nullptr}; + std::unordered_map tensor_index; + std::vector tensor_list; + std::unordered_map vm_obj_index; + std::vector vm_obj_list; + + void Visit(const char* key, double* value) final {} + void Visit(const char* key, int64_t* value) final {} + void Visit(const char* key, uint64_t* value) final {} + void Visit(const char* key, int* value) final {} + void Visit(const char* key, bool* value) final {} + void Visit(const char* key, std::string* value) final {} + void Visit(const char* key, void** value) final {} + void Visit(const char* key, Type* value) final {} + void Visit(const char* key, NodeRef* value) final { + MakeIndex(value->node_.get()); + } + + void Visit(const char* key, runtime::NDArray* value) final { + DLTensor* ptr = const_cast((*value).operator->()); + if (tensor_index.count(ptr)) return; + CHECK_EQ(tensor_index.size(), tensor_list.size()); + tensor_index[ptr] = tensor_list.size(); + tensor_list.push_back(ptr); + } + + void Visit(const char* key, Object* value) final { + ObjectCell* ptr = value->ptr_.get(); + if (vm_obj_index.count(ptr)) return; + CHECK_EQ(vm_obj_index.size(), vm_obj_list.size()); + vm_obj_index[ptr] = vm_obj_list.size(); + vm_obj_list.push_back(ptr); + } + + // make index of all the children of node + void MakeIndex(Node* node) { + if (node == nullptr) return; + if (node_index.count(node)) return; + CHECK_EQ(node_index.size(), node_list.size()); + node_index[node] = node_list.size(); + node_list.push_back(node); + + if (node->is_type()) { + ArrayNode* n = static_cast(node); + for (const auto& sp : n->data) { + MakeIndex(sp.get()); + } + } else if (node->is_type()) { + MapNode* n = static_cast(node); + for (const auto& kv : n->data) { + MakeIndex(kv.first.get()); + MakeIndex(kv.second.get()); + } + } else if (node->is_type()) { + StrMapNode* n = static_cast(node); + for (const auto& kv : n->data) { + MakeIndex(kv.second.get()); + } + } else { + node->VisitAttrs(this); + } + } +}; + +// A Node structure for JSON node. +struct JSONNode { + // The type key of the data + std::string type_key; + // The global key for global object + std::string global_key; + // the attributes + AttrMap attrs; + // container keys + std::vector keys; + // container data + std::vector data; + + void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("type_key", type_key); + if (global_key.size() != 0) { + writer->WriteObjectKeyValue("global_key", global_key); + } + if (attrs.size() != 0) { + writer->WriteObjectKeyValue("attrs", attrs); + } + if (keys.size() != 0) { + writer->WriteObjectKeyValue("keys", keys); + } + if (data.size() != 0) { + writer->WriteObjectKeyValue("data", data); + } + writer->EndObject(); + } + + void Load(dmlc::JSONReader *reader) { + attrs.clear(); + data.clear(); + global_key.clear(); + type_key.clear(); + dmlc::JSONObjectReadHelper helper; + helper.DeclareOptionalField("type_key", &type_key); + helper.DeclareOptionalField("global_key", &global_key); + helper.DeclareOptionalField("attrs", &attrs); + helper.DeclareOptionalField("keys", &keys); + helper.DeclareOptionalField("data", &data); + helper.ReadAllFields(reader); + } +}; + +class JSONAttrGetter : public AttrVisitor { + public: + const std::unordered_map* node_index_; + const std::unordered_map* tensor_index_; + const std::unordered_map* vm_obj_index_; + JSONNode* node_; + + void Visit(const char* key, double* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, int64_t* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, uint64_t* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, int* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, bool* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, std::string* value) final { + node_->attrs[key] = *value; + } + void Visit(const char* key, void** value) final { + LOG(FATAL) << "not allowed to serialize a pointer"; + } + void Visit(const char* key, Type* value) final { + node_->attrs[key] = Type2String(*value); + } + void Visit(const char* key, NodeRef* value) final { + node_->attrs[key] = std::to_string( + node_index_->at(value->node_.get())); + } + void Visit(const char* key, runtime::NDArray* value) final { + node_->attrs[key] = std::to_string( + tensor_index_->at(const_cast((*value).operator->()))); + } + void Visit(const char* key, Object* value) final { + node_->attrs[key] = std::to_string( + vm_obj_index_->at(value->ptr_.get())); + } + // Get the node + void Get(Node* node) { + if (node == nullptr) { + node_->type_key.clear(); + return; + } + node_->type_key = node->type_key(); + // sepcially handle global object + auto* f = dmlc::Registry::Find(node_->type_key); + CHECK(f != nullptr) + << "Node type \'" << node_->type_key << "\' is not registered in TVM"; + if (f->fglobal_key != nullptr) { + node_->global_key = f->fglobal_key(node); + return; + } + node_->attrs.clear(); + node_->data.clear(); + if (node->is_type()) { + ArrayNode* n = static_cast(node); + for (size_t i = 0; i < n->data.size(); ++i) { + node_->data.push_back( + node_index_->at(n->data[i].get())); + } + } else if (node->is_type()) { + MapNode* n = static_cast(node); + for (const auto& kv : n->data) { + node_->data.push_back( + node_index_->at(kv.first.get())); + node_->data.push_back( + node_index_->at(kv.second.get())); + } + } else if (node->is_type()) { + StrMapNode* n = static_cast(node); + for (const auto& kv : n->data) { + node_->keys.push_back(kv.first); + node_->data.push_back( + node_index_->at(kv.second.get())); + } + } else { + // do not need to recover content of global singleton object + // they are registered via the environment + auto* f = dmlc::Registry::Find(node->type_key()); + if (f != nullptr && f->fglobal_key != nullptr) return; + // recursively index normal object. + node->VisitAttrs(this); + } + } +}; + +// json graph structure to store node +struct JSONGraph { + // the root of the graph + size_t root; + // the nodes of the graph + std::vector nodes; + // base64 b64ndarrays of arrays + std::vector b64ndarrays; + // global attributes + AttrMap attrs; + + void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("root", root); + writer->WriteObjectKeyValue("nodes", nodes); + writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays); + if (attrs.size() != 0) { + writer->WriteObjectKeyValue("attrs", attrs); + } + writer->EndObject(); + } + + void Load(dmlc::JSONReader *reader) { + attrs.clear(); + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("root", &root); + helper.DeclareField("nodes", &nodes); + helper.DeclareOptionalField("b64ndarrays", &b64ndarrays); + helper.DeclareOptionalField("attrs", &attrs); + helper.ReadAllFields(reader); + } + + static JSONGraph Create(const NodeRef& root) { + JSONGraph g; + NodeIndexer indexer; + indexer.MakeIndex(root.node_.get()); + JSONAttrGetter getter; + getter.node_index_ = &indexer.node_index; + getter.tensor_index_ = &indexer.tensor_index; + for (Node* n : indexer.node_list) { + JSONNode jnode; + getter.node_ = &jnode; + getter.Get(n); + g.nodes.emplace_back(std::move(jnode)); + } + g.attrs["tvm_version"] = TVM_VERSION; + g.root = indexer.node_index.at(root.node_.get()); + // serialize tensor + for (DLTensor* tensor : indexer.tensor_list) { + std::string blob; + dmlc::MemoryStringStream mstrm(&blob); + common::Base64OutStream b64strm(&mstrm); + runtime::SaveDLTensor(&b64strm, tensor); + b64strm.Finish(); + g.b64ndarrays.emplace_back(std::move(blob)); + } + return g; + } +}; + +} // namespace tvm +#endif // TVM_JSON_H_ diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index c483f4f75900..3c0077d4e078 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -15,14 +15,14 @@ # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-import +# pylint: disable=invalid-name, unused-argument """A parser for Relay's text format.""" from __future__ import absolute_import import sys +from ast import literal_eval from collections import deque -from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any, Dict import tvm @@ -32,6 +32,23 @@ from . import ty from . import op +PYTHON_VERSION = sys.version_info.major +try: + from .grammar.py3.RelayVisitor import RelayVisitor + from .grammar.py3.RelayParser import RelayParser + from .grammar.py3.RelayLexer import RelayLexer +except ImportError: + raise Exeption("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.") + +try: + from antlr4 import InputStream, CommonTokenStream + from antlr4.error.ErrorListener import ErrorListener +except ImportError: + raise Exception("Couldn't find ANTLR runtime." + + "Try running `pip{version} install antlr4-python{version}-runtime`." + .format(version=PYTHON_VERSION)) + +sys.setrecursionlimit(10000) class ParseError(Exception): """Exception type for parse errors.""" @@ -41,21 +58,50 @@ def __init__(self, message): super(ParseError, self).__init__() self.message = message -PYTHON_VERSION = sys.version_info.major -try: - from .grammar.py3.RelayVisitor import RelayVisitor - from .grammar.py3.RelayParser import RelayParser - from .grammar.py3.RelayLexer import RelayLexer -except ImportError: - raise ParseError("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.") + def __repr__(self): + return "ParseError({})".format(self.message) -try: - from antlr4 import ParserRuleContext, InputStream, CommonTokenStream - from antlr4.tree.Tree import TerminalNode -except ImportError: - raise ParseError("Couldn't find ANTLR runtime." + - "Try running `pip{version} install antlr4-python{version}-runtime`." - .format(version=PYTHON_VERSION)) + def __str__(self): + return repr(self) + +class OpWrapper: + """Overload the __call__ for op.""" + pass + +class ExprOp(OpWrapper): + """Call an expr. The default, but does not handle attrs well.""" + def __init__(self, operator): + self.operator = operator + + def __call__(self, args, attrs, type_args): + try: + return expr.Call(self.operator, args, attrs, type_args) + except Exception: + raise Exception(str(self.operator) + " " + str(attrs)) + +class FuncOp(OpWrapper): + """Convert the attrs, call the python function with the attrs passed in as keyword arguments. + Tvm should provide this in the future, as this is pretty similar to what op.get is providing. + """ + def __init__(self, operator): + self.operator = operator + + def convert(self, v): + if isinstance(v, tuple): + return tuple([self.convert(x) for x in v]) + if isinstance(v, expr.Constant): + return v.data.asnumpy().item() + if isinstance(v, str): + return v + raise Exception(v) + + def __call__(self, args, attrs, type_args): + if attrs is None: + attrs = {} + x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()}) + if isinstance(x, expr.TupleWrapper): + x = x.astuple() + return x BINARY_OPS = { RelayParser.MUL: op.multiply, @@ -70,6 +116,24 @@ def __init__(self, message): RelayParser.NE: op.not_equal, } +FUNC_OPS = { + "nn.conv2d": op.nn.conv2d, + "nn.batch_norm": op.nn.batch_norm, + "nn.dense": op.nn.dense, + "nn.bias_add": op.nn.bias_add, + "nn.max_pool2d": op.nn.max_pool2d, + "nn.global_max_pool2d": op.nn.global_max_pool2d, + "nn.avg_pool2d": op.nn.avg_pool2d, + "nn.global_avg_pool2d": op.nn.global_avg_pool2d, + "nn.softmax": op.nn.softmax, + "reshape": op.reshape, + "nn.conv2d_transpose": op.nn.conv2d_transpose, + "concatenate": op.concatenate, + "nn.dropout": op.nn.dropout_raw, + "zeros": op.zeros, + "split": op.split, +} + TYPE_PREFIXES = [ "int", "uint", @@ -108,6 +172,8 @@ def _wrapper(*args, **kwargs): ast = f(*args, **kwargs) line, col = ctx.getSourceInterval() sp = Span(sn, line, col) + if isinstance(ast, tvm.relay.expr.TupleWrapper): + ast = ast.astuple() ast.set_span(sp) return ast return _wrapper @@ -179,6 +245,9 @@ def mk_typ(self, name, kind): self.type_param_scopes[0].appendleft((name, typ)) return typ + def visitProjection(self, ctx): + return expr.TupleGetItem(self.visit(ctx.expr()), self.visit(ctx.NAT())) + def visitTerminal(self, node): # type: (TerminalNode) -> Union[expr.Expr, int, float] """Visit lexer tokens that aren't ignored or visited by other functions.""" @@ -213,12 +282,15 @@ def visitTerminal(self, node): if node_text == "False": return False raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text)) + if node_type == RelayLexer.QUOTED_STRING: + return literal_eval(node_text) - raise ParseError("todo: {}".format(node_text)) + raise ParseError("todo: `{}`".format(node_text)) def visit_list(self, ctx_list): # type: (List[ParserRuleContext]) -> List[Any] """"Visit a list of contexts.""" + assert isinstance(ctx_list, list) return [self.visit(ctx) for ctx in ctx_list] @@ -232,6 +304,11 @@ def getType_(self, ctx): return self.visit(ctx) def visitProg(self, ctx): + self.meta = None + if ctx.METADATA(): + header, data = str(ctx.METADATA()).split('\n', 1) + assert header == "METADATA:" + self.meta = tvm.load_json(data) # type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module] if ctx.defn(): self.visit_list(ctx.defn()) @@ -245,11 +322,14 @@ def visitProg(self, ctx): # Exprs def visitOpIdent(self, ctx): # type: (RelayParser.OpIdentContext) -> op.Op - return op.get(ctx.CNAME().getText()) + op_name = ctx.CNAME().getText() + if op_name in FUNC_OPS: + return FuncOp(FUNC_OPS[op_name]) + return ExprOp(op.get(op_name)) # pass through - def visitParens(self, ctx): - # type: (RelayParser.ParensContext) -> expr.Expr + def visitParen(self, ctx): + # type: (RelayParser.ParenContext) -> expr.Expr return self.visit(ctx.expr()) # pass through @@ -283,25 +363,17 @@ def visitTuple(self, ctx): tup = self.visit_list(ctx.expr()) return expr.Tuple(tup) - # Currently doesn't support mutable sequencing. def visitLet(self, ctx): # type: (RelayParser.SeqContext) -> expr.Let """Desugar various sequence constructs to Relay Let nodes.""" - if ctx.MUT() is not None: - raise ParseError("Mutation is currently unsupported.") - if ctx.var() is None or ctx.var().ident() is None: + if ctx.var() is None: # anonymous identity ident = "_" type_ = None + var = self.mk_var(ident, type_) else: - local_var = ctx.var().ident().LOCAL_VAR() - if local_var is None: - raise ParseError("Only local ids may be used in `let`s.") - ident = local_var.getText()[1:] - type_ = self.getType_(ctx.var().type_()) - - var = self.mk_var(ident, type_) + var = self.visitVar(ctx.var()) self.enter_var_scope() value = self.visit(ctx.expr(0)) @@ -326,7 +398,7 @@ def visitBinOp(self, ctx): def visitVar(self, ctx): # type: (RelayParser.VarContext) -> expr.Var """Visit a single variable.""" - ident = ctx.ident().LOCAL_VAR() + ident = ctx.LOCAL_VAR() if ident is None: raise ParseError("Only local ids may be used in vars.") @@ -344,19 +416,29 @@ def visitAttr(self, ctx): # type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr] return (ctx.CNAME().getText(), self.visit(ctx.expr())) - def visitAttrList(self, ctx): + def visitArgNoAttr(self, ctx): + return (self.visit_list(ctx.varList().var()), None) + + def visitAttrSeq(self, ctx): # type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr] return dict(self.visit_list(ctx.attr())) + def visitArgWithAttr(self, ctx): + return (self.visit_list(ctx.var()), self.visitAttrSeq(ctx.attrSeq())) + def visitArgList(self, ctx # type: RelayParser.ArgListContext ): # type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]] var_list = self.visit(ctx.varList()) if ctx.varList() else None attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None - return (var_list, attr_list) + def visitMeta(self, ctx): + type_key = str(ctx.CNAME()) + index = int(self.visit(ctx.NAT())) + return self.meta[type_key][index] + def mk_func(self, ctx): # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function """Construct a function from either a Func or Defn.""" @@ -365,7 +447,7 @@ def mk_func(self, ctx): self.enter_var_scope() # Capture type params in params. self.enter_type_param_scope() - type_params = ctx.typeParamSeq() + type_params = ctx.typeParamList() if type_params is not None: type_params = type_params.ident() @@ -405,18 +487,25 @@ def visitDefn(self, ctx): raise ParseError("Only global ids may be used in `def`s.") ident_name = ident.getText()[1:] ident = self.mk_global_var(ident_name) - self.module[ident] = self.mk_func(ctx) + def visitCallNoAttr(self, ctx): + return (self.visit_list(ctx.exprList().expr()), None) + + def visitCallWithAttr(self, ctx): + return (self.visit_list(ctx.expr()), self.visit(ctx.attrSeq())) + + def call(self, func, args, attrs, type_args): + if isinstance(func, OpWrapper): + return func(args, attrs, type_args) + return expr.Call(func, args, attrs, type_args) + @spanify def visitCall(self, ctx): # type: (RelayParser.CallContext) -> expr.Call - visited_exprs = self.visit_list(ctx.expr()) - - func = visited_exprs[0] - args = visited_exprs[1:] - - return expr.Call(func, args, None, None) + func = self.visit(ctx.expr()) + args, attrs = self.visit(ctx.callList()) + return self.call(func, args, attrs, []) @spanify def visitIfElse(self, ctx): @@ -438,9 +527,7 @@ def visitIfElse(self, ctx): def visitGraph(self, ctx): # type: (RelayParser.GraphContext) -> expr.Expr """Visit a graph variable assignment.""" - if ctx.ident().GRAPH_VAR() is None: - raise ParseError("Expected a graph var, but got `{}`".format(ctx.ident().getText())) - graph_nid = int(ctx.ident().GRAPH_VAR().getText()[1:]) + graph_nid = int(ctx.GRAPH_VAR().getText()[1:]) self.enter_var_scope() value = self.visit(ctx.expr(0)) @@ -500,15 +587,18 @@ def visitParensShape(self, ctx): # type: (RelayParser.ParensShapeContext) -> int return self.visit(ctx.shape()) - def visitShapeSeq(self, ctx): - # type: (RelayParser.ShapeSeqContext) -> List[int] + def visitShapeList(self, ctx): + # type: (RelayParser.ShapeListContext) -> List[int] return self.visit_list(ctx.shape()) + def visitTensor(self, ctx): + return tuple(self.visit_list(ctx.expr())) + def visitTensorType(self, ctx): # type: (RelayParser.TensorTypeContext) -> ty.TensorType """Create a simple tensor type. No generics.""" - shape = self.visit(ctx.shapeSeq()) + shape = self.visit(ctx.shapeList()) dtype = self.visit(ctx.type_()) if not isinstance(dtype, ty.TensorType): @@ -536,11 +626,37 @@ def make_parser(data): """Construct a RelayParser a given data stream.""" input_stream = InputStream(data) lexer = RelayLexer(input_stream) + lexer.addErrorListener(StrictErrorListener(data)) token_stream = CommonTokenStream(lexer) - return RelayParser(token_stream) + p = RelayParser(token_stream) + p.addErrorListener(StrictErrorListener(data)) + return p __source_name_counter__ = 0 +class StrictErrorListener(ErrorListener): + """This ErrorListener fail eagerly on all error, and report the program.""" + def __init__(self, text): + self.text = text + + def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e): + raise Exception("Syntax Error in:\n" + self.text) + + def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs): + raise Exception("Ambiguity Error in:\n" + self.text) + + def reportAttemptingFullContext(self, + recognizer, + dfa, + startIndex, + stopIndex, + conflictingAlts, + configs): + raise Exception("Attempting Full Context in:\n" + self.text) + + def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs): + raise Exception("Context Sensitivity in:\n" + self.text) + def fromtext(data, source_name=None): # type: (str, str) -> Union[expr.Expr, module.Module] """Parse a Relay program.""" diff --git a/python/tvm/relay/analysis.py b/python/tvm/relay/analysis.py index ee8ce985fcbc..91b53bb5f196 100644 --- a/python/tvm/relay/analysis.py +++ b/python/tvm/relay/analysis.py @@ -224,6 +224,20 @@ def alpha_equal(lhs, rhs): return bool(_make._alpha_equal(lhs, rhs)) +def assert_alpha_equal(lhs, rhs): + """Assert that two Relay expr is structurally equivalent. (alpha equivalence). + + Parameters + ---------- + lhs : tvm.relay.Expr + One of the input Expression. + + rhs : tvm.relay.Expr + One of the input Expression. + """ + _make._assert_alpha_equal(lhs, rhs) + + def graph_equal(lhs, rhs): """Compare two Relay expr for data-flow equivalence. The difference between this and alpha-equality is that @@ -246,6 +260,23 @@ def graph_equal(lhs, rhs): return bool(_make._graph_equal(lhs, rhs)) +def assert_graph_equal(lhs, rhs): + """Compare two Relay expr for data-flow equivalence. + The difference between this and alpha-equality is that + variables are not expected to match between lhs and rhs; + they are treated as sources and are mapped between each other. + + Parameters + ---------- + lhs : tvm.relay.Expr + One of the input Expression. + + rhs : tvm.relay.Expr + One of the input Expression. + """ + _make._assert_graph_equal(lhs, rhs) + + def collect_device_info(expr): """Collect the device allocation map for the given expression. The device ids are propagated from the `device_copy` operators. diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 916c4a6c378a..8830a4122e08 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -17,15 +17,20 @@ * under the License. */ +// list = *, seq = ? + grammar Relay; SEMVER: 'v0.0.3' ; // Lexing // comments -WS : [ \t\n\r]+ -> skip ; -LINE_COMMENT : '//' .*? '\n' -> skip ; -COMMENT : '/*' .*? '*/' -> skip ; +COMMENT : '/*' (COMMENT|.)*? '*/' -> skip; +WS : [ \t\n\r]+ -> skip; +LINE_COMMENT : '//' .*? '\n' -> skip; + +fragment ESCAPED_QUOTE : '\\"'; +QUOTED_STRING : '"' ( ESCAPED_QUOTE | ~('\n'|'\r') )*? '"'; // operators MUL: '*' ; @@ -39,18 +44,18 @@ GE: '>=' ; EQ: '==' ; NE: '!=' ; -opIdent: CNAME ; -GLOBAL_VAR: '@' CNAME ; -LOCAL_VAR: '%' CNAME; -GRAPH_VAR: '%' NAT; - -MUT: 'mut' ; - BOOL_LIT : 'True' | 'False' ; +CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)*; +opIdent: CNAME ; +GLOBAL_VAR: '@' CNAME ; +LOCAL_VAR: '%' CNAME; +GRAPH_VAR: '%' NAT; + +DATATYPE : 'int64'; // non-negative floats fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4 @@ -60,109 +65,99 @@ FLOAT : PREFLOAT 'f'; NAT: DIGIT+ ; fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...] -CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ; -fragment LETTER: [a-zA-Z] ; -fragment DIGIT: [0-9] ; +fragment LETTER: [a-zA-Z]; +fragment DIGIT: [0-9]; +METADATA: 'METADATA:' .*; // Parsing // A Relay program is a list of global definitions or an expression. -prog: SEMVER (defn* | expr) EOF ; +prog: SEMVER (defn* | expr) METADATA? EOF ; // option: 'set' ident BOOL_LIT ; +exprList: (expr (',' expr)*)?; +callList + : exprList # callNoAttr + | (expr ',')* attrSeq # callWithAttr + ; + expr // operators - : '(' expr ')' # parens + : '(' expr ')' # paren + | '{' expr '}' # paren // function application - | expr '(' (expr (',' expr)*)? ')' # call + | expr '(' callList ')' # call | '-' expr # neg | expr op=('*'|'/') expr # binOp | expr op=('+'|'-') expr # binOp | expr op=('<'|'>'|'<='|'>=') expr # binOp | expr op=('=='|'!=') expr # binOp - // function definition | func # funcExpr - // tuples and tensors | '(' ')' # tuple | '(' expr ',' ')' # tuple | '(' expr (',' expr)+ ')' # tuple + | expr '.' NAT # projection | '[' (expr (',' expr)*)? ']' # tensor - | 'if' '(' expr ')' body 'else' body # ifElse - // sequencing - | 'let' MUT? var '=' expr ';' expr # let - | 'let' MUT? var '=' '{' expr '}' ';' expr # let + | 'let' var '=' expr ';' expr # let // sugar for let %_ = expr; expr - | expr ';' expr # let - | ident '=' expr ';' expr # graph - - // mutable update - // | ident '=' expr # writeRef - // | expr '^' # readRef - + | expr ';;' expr # let + | GRAPH_VAR '=' expr ';' expr # graph | ident # identExpr | scalar # scalarExpr - // | expr '.' NAT # project - // | 'debug' # debug + | meta # metaExpr + | QUOTED_STRING # stringExpr ; -func: 'fn' typeParamSeq? '(' argList ')' ('->' type_)? body ; -defn: 'def' ident typeParamSeq? '(' argList ')' ('->' type_)? body ; +func: 'fn' typeParamList? '(' argList ')' ('->' type_)? body ; +defn: 'def' ident typeParamList? '(' argList ')' ('->' type_)? body ; argList - : varList - | attrList - | varList ',' attrList + : varList # argNoAttr + | (var ',')* attrSeq # argWithAttr ; -varList: (var (',' var)*)? ; -var: ident (':' type_)? ; +varList: (var (',' var)*)?; +var: LOCAL_VAR (':' type_)?; -attrList: (attr (',' attr)*)? ; +attrSeq: attr (',' attr)*; attr: CNAME '=' expr ; -// TODO(@jmp): for improved type annotations -// returnAnno: (ident ':')? type_ ; - -// relations: 'where' relation (',' relation)* ; -// relation: ident '(' (type_ (',' type_)*)? ')' ; - -typeParamSeq +typeParamList : '[' ']' | '[' ident (',' ident)* ']' ; type_ - : '(' ')' # tupleType - | '(' type_ ',' ')' # tupleType - | '(' type_ (',' type_)+ ')' # tupleType - | typeIdent # typeIdentType - | 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType - // currently unused - // | typeIdent '[' (type_ (',' type_)*)? ']' # callType - | 'fn' typeParamSeq? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType - | '_' # incompleteType - | NAT # intType + : '(' ')' # tupleType + | '(' type_ ',' ')' # tupleType + | '(' type_ (',' type_)+ ')' # tupleType + | typeIdent # typeIdentType + | 'Tensor' '[' shapeList ',' type_ ']' # tensorType + | 'fn' typeParamList? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType + | '_' # incompleteType + | NAT # intType ; -shapeSeq - : '(' ')' - | '(' shape ',' ')' - | '(' shape (',' shape)+ ')' +shapeList + : '(' shape (',' shape)+ ')' + | '(' ')' + | shape ; +meta : 'meta' '[' CNAME ']' '[' NAT ']'; + shape - : '(' shape ')' # parensShape - // | type_ op=('*'|'/') type_ # binOpType - // | type_ op=('+'|'-') type_ # binOpType - | NAT # intShape + : meta # metaShape + | '(' shape ')' # parensShape + | NAT # intShape ; -typeIdent : CNAME ; +typeIdent : CNAME; // int8, int16, int32, int64 // uint8, uint16, uint32, uint64 // float16, float32, float64 diff --git a/python/tvm/relay/grammar/py3/RelayLexer.py b/python/tvm/relay/grammar/py3/RelayLexer.py index 11c9c01b7f75..eec2e65d5666 100644 --- a/python/tvm/relay/grammar/py3/RelayLexer.py +++ b/python/tvm/relay/grammar/py3/RelayLexer.py @@ -7,116 +7,147 @@ def serializedATN(): with StringIO() as buf: - buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2*") - buf.write("\u010d\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7") + buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2/") + buf.write("\u014a\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7") buf.write("\t\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r") buf.write("\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23") buf.write("\t\23\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30") buf.write("\4\31\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36") buf.write("\t\36\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\4$\t$\4%\t%") - buf.write("\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t,\4-\t-\3\2") - buf.write("\3\2\3\3\3\3\3\4\3\4\3\5\3\5\3\6\3\6\3\7\3\7\3\7\3\b\3") - buf.write("\b\3\b\3\b\3\b\3\t\3\t\3\t\3\t\3\n\3\n\3\13\3\13\3\f\3") - buf.write("\f\3\r\3\r\3\16\3\16\3\16\3\17\3\17\3\17\3\20\3\20\3\20") - buf.write("\3\20\3\21\3\21\3\22\3\22\3\22\3\22\3\22\3\22\3\22\3\23") - buf.write("\3\23\3\24\3\24\3\24\3\24\3\24\3\24\3\24\3\25\6\25\u0097") - buf.write("\n\25\r\25\16\25\u0098\3\25\3\25\3\26\3\26\3\26\3\26\7") - buf.write("\26\u00a1\n\26\f\26\16\26\u00a4\13\26\3\26\3\26\3\26\3") - buf.write("\26\3\27\3\27\3\27\3\27\7\27\u00ae\n\27\f\27\16\27\u00b1") - buf.write("\13\27\3\27\3\27\3\27\3\27\3\27\3\30\3\30\3\31\3\31\3") - buf.write("\32\3\32\3\33\3\33\3\34\3\34\3\35\3\35\3\36\3\36\3\36") - buf.write("\3\37\3\37\3\37\3 \3 \3 \3!\3!\3!\3\"\3\"\3\"\3#\3#\3") - buf.write("#\3$\3$\3$\3%\3%\3%\3%\3&\3&\3&\3&\3&\3&\3&\3&\3&\5&\u00e6") - buf.write("\n&\3\'\3\'\3\'\5\'\u00eb\n\'\3\'\5\'\u00ee\n\'\3(\3(") - buf.write("\3(\3)\6)\u00f4\n)\r)\16)\u00f5\3*\3*\5*\u00fa\n*\3*\3") - buf.write("*\3+\3+\5+\u0100\n+\3+\3+\3+\7+\u0105\n+\f+\16+\u0108") - buf.write("\13+\3,\3,\3-\3-\4\u00a2\u00af\2.\3\3\5\4\7\5\t\6\13\7") - buf.write("\r\b\17\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20\37\21") - buf.write("!\22#\23%\24\'\25)\26+\27-\30/\31\61\32\63\33\65\34\67") - buf.write("\359\36;\37= ?!A\"C#E$G%I&K\'M\2O(Q)S\2U*W\2Y\2\3\2\7") - buf.write("\5\2\13\f\17\17\"\"\4\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2") - buf.write("\u0114\2\3\3\2\2\2\2\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2") - buf.write("\2\13\3\2\2\2\2\r\3\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2") - buf.write("\23\3\2\2\2\2\25\3\2\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33") - buf.write("\3\2\2\2\2\35\3\2\2\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2") - buf.write("\2\2%\3\2\2\2\2\'\3\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2") - buf.write("\2\2\2/\3\2\2\2\2\61\3\2\2\2\2\63\3\2\2\2\2\65\3\2\2\2") - buf.write("\2\67\3\2\2\2\29\3\2\2\2\2;\3\2\2\2\2=\3\2\2\2\2?\3\2") - buf.write("\2\2\2A\3\2\2\2\2C\3\2\2\2\2E\3\2\2\2\2G\3\2\2\2\2I\3") - buf.write("\2\2\2\2K\3\2\2\2\2O\3\2\2\2\2Q\3\2\2\2\2U\3\2\2\2\3[") - buf.write("\3\2\2\2\5]\3\2\2\2\7_\3\2\2\2\ta\3\2\2\2\13c\3\2\2\2") - buf.write("\re\3\2\2\2\17h\3\2\2\2\21m\3\2\2\2\23q\3\2\2\2\25s\3") - buf.write("\2\2\2\27u\3\2\2\2\31w\3\2\2\2\33y\3\2\2\2\35|\3\2\2\2") - buf.write("\37\177\3\2\2\2!\u0083\3\2\2\2#\u0085\3\2\2\2%\u008c\3") - buf.write("\2\2\2\'\u008e\3\2\2\2)\u0096\3\2\2\2+\u009c\3\2\2\2-") - buf.write("\u00a9\3\2\2\2/\u00b7\3\2\2\2\61\u00b9\3\2\2\2\63\u00bb") - buf.write("\3\2\2\2\65\u00bd\3\2\2\2\67\u00bf\3\2\2\29\u00c1\3\2") - buf.write("\2\2;\u00c3\3\2\2\2=\u00c6\3\2\2\2?\u00c9\3\2\2\2A\u00cc") - buf.write("\3\2\2\2C\u00cf\3\2\2\2E\u00d2\3\2\2\2G\u00d5\3\2\2\2") - buf.write("I\u00d8\3\2\2\2K\u00e5\3\2\2\2M\u00e7\3\2\2\2O\u00ef\3") - buf.write("\2\2\2Q\u00f3\3\2\2\2S\u00f7\3\2\2\2U\u00ff\3\2\2\2W\u0109") - buf.write("\3\2\2\2Y\u010b\3\2\2\2[\\\7*\2\2\\\4\3\2\2\2]^\7+\2\2") - buf.write("^\6\3\2\2\2_`\7.\2\2`\b\3\2\2\2ab\7]\2\2b\n\3\2\2\2cd") - buf.write("\7_\2\2d\f\3\2\2\2ef\7k\2\2fg\7h\2\2g\16\3\2\2\2hi\7g") - buf.write("\2\2ij\7n\2\2jk\7u\2\2kl\7g\2\2l\20\3\2\2\2mn\7n\2\2n") - buf.write("o\7g\2\2op\7v\2\2p\22\3\2\2\2qr\7?\2\2r\24\3\2\2\2st\7") - buf.write("=\2\2t\26\3\2\2\2uv\7}\2\2v\30\3\2\2\2wx\7\177\2\2x\32") - buf.write("\3\2\2\2yz\7h\2\2z{\7p\2\2{\34\3\2\2\2|}\7/\2\2}~\7@\2") - buf.write("\2~\36\3\2\2\2\177\u0080\7f\2\2\u0080\u0081\7g\2\2\u0081") - buf.write("\u0082\7h\2\2\u0082 \3\2\2\2\u0083\u0084\7<\2\2\u0084") - buf.write("\"\3\2\2\2\u0085\u0086\7V\2\2\u0086\u0087\7g\2\2\u0087") - buf.write("\u0088\7p\2\2\u0088\u0089\7u\2\2\u0089\u008a\7q\2\2\u008a") - buf.write("\u008b\7t\2\2\u008b$\3\2\2\2\u008c\u008d\7a\2\2\u008d") - buf.write("&\3\2\2\2\u008e\u008f\7x\2\2\u008f\u0090\7\62\2\2\u0090") - buf.write("\u0091\7\60\2\2\u0091\u0092\7\62\2\2\u0092\u0093\7\60") - buf.write("\2\2\u0093\u0094\7\65\2\2\u0094(\3\2\2\2\u0095\u0097\t") - buf.write("\2\2\2\u0096\u0095\3\2\2\2\u0097\u0098\3\2\2\2\u0098\u0096") - buf.write("\3\2\2\2\u0098\u0099\3\2\2\2\u0099\u009a\3\2\2\2\u009a") - buf.write("\u009b\b\25\2\2\u009b*\3\2\2\2\u009c\u009d\7\61\2\2\u009d") - buf.write("\u009e\7\61\2\2\u009e\u00a2\3\2\2\2\u009f\u00a1\13\2\2") - buf.write("\2\u00a0\u009f\3\2\2\2\u00a1\u00a4\3\2\2\2\u00a2\u00a3") - buf.write("\3\2\2\2\u00a2\u00a0\3\2\2\2\u00a3\u00a5\3\2\2\2\u00a4") - buf.write("\u00a2\3\2\2\2\u00a5\u00a6\7\f\2\2\u00a6\u00a7\3\2\2\2") - buf.write("\u00a7\u00a8\b\26\2\2\u00a8,\3\2\2\2\u00a9\u00aa\7\61") - buf.write("\2\2\u00aa\u00ab\7,\2\2\u00ab\u00af\3\2\2\2\u00ac\u00ae") - buf.write("\13\2\2\2\u00ad\u00ac\3\2\2\2\u00ae\u00b1\3\2\2\2\u00af") - buf.write("\u00b0\3\2\2\2\u00af\u00ad\3\2\2\2\u00b0\u00b2\3\2\2\2") - buf.write("\u00b1\u00af\3\2\2\2\u00b2\u00b3\7,\2\2\u00b3\u00b4\7") - buf.write("\61\2\2\u00b4\u00b5\3\2\2\2\u00b5\u00b6\b\27\2\2\u00b6") - buf.write(".\3\2\2\2\u00b7\u00b8\7,\2\2\u00b8\60\3\2\2\2\u00b9\u00ba") - buf.write("\7\61\2\2\u00ba\62\3\2\2\2\u00bb\u00bc\7-\2\2\u00bc\64") - buf.write("\3\2\2\2\u00bd\u00be\7/\2\2\u00be\66\3\2\2\2\u00bf\u00c0") - buf.write("\7>\2\2\u00c08\3\2\2\2\u00c1\u00c2\7@\2\2\u00c2:\3\2\2") - buf.write("\2\u00c3\u00c4\7>\2\2\u00c4\u00c5\7?\2\2\u00c5<\3\2\2") - buf.write("\2\u00c6\u00c7\7@\2\2\u00c7\u00c8\7?\2\2\u00c8>\3\2\2") - buf.write("\2\u00c9\u00ca\7?\2\2\u00ca\u00cb\7?\2\2\u00cb@\3\2\2") - buf.write("\2\u00cc\u00cd\7#\2\2\u00cd\u00ce\7?\2\2\u00ceB\3\2\2") - buf.write("\2\u00cf\u00d0\7B\2\2\u00d0\u00d1\5U+\2\u00d1D\3\2\2\2") - buf.write("\u00d2\u00d3\7\'\2\2\u00d3\u00d4\5U+\2\u00d4F\3\2\2\2") - buf.write("\u00d5\u00d6\7\'\2\2\u00d6\u00d7\5Q)\2\u00d7H\3\2\2\2") - buf.write("\u00d8\u00d9\7o\2\2\u00d9\u00da\7w\2\2\u00da\u00db\7v") - buf.write("\2\2\u00dbJ\3\2\2\2\u00dc\u00dd\7V\2\2\u00dd\u00de\7t") - buf.write("\2\2\u00de\u00df\7w\2\2\u00df\u00e6\7g\2\2\u00e0\u00e1") - buf.write("\7H\2\2\u00e1\u00e2\7c\2\2\u00e2\u00e3\7n\2\2\u00e3\u00e4") - buf.write("\7u\2\2\u00e4\u00e6\7g\2\2\u00e5\u00dc\3\2\2\2\u00e5\u00e0") - buf.write("\3\2\2\2\u00e6L\3\2\2\2\u00e7\u00ea\5Q)\2\u00e8\u00e9") - buf.write("\7\60\2\2\u00e9\u00eb\5Q)\2\u00ea\u00e8\3\2\2\2\u00ea") - buf.write("\u00eb\3\2\2\2\u00eb\u00ed\3\2\2\2\u00ec\u00ee\5S*\2\u00ed") - buf.write("\u00ec\3\2\2\2\u00ed\u00ee\3\2\2\2\u00eeN\3\2\2\2\u00ef") - buf.write("\u00f0\5M\'\2\u00f0\u00f1\7h\2\2\u00f1P\3\2\2\2\u00f2") - buf.write("\u00f4\5Y-\2\u00f3\u00f2\3\2\2\2\u00f4\u00f5\3\2\2\2\u00f5") - buf.write("\u00f3\3\2\2\2\u00f5\u00f6\3\2\2\2\u00f6R\3\2\2\2\u00f7") - buf.write("\u00f9\t\3\2\2\u00f8\u00fa\t\4\2\2\u00f9\u00f8\3\2\2\2") - buf.write("\u00f9\u00fa\3\2\2\2\u00fa\u00fb\3\2\2\2\u00fb\u00fc\5") - buf.write("Q)\2\u00fcT\3\2\2\2\u00fd\u0100\7a\2\2\u00fe\u0100\5W") - buf.write(",\2\u00ff\u00fd\3\2\2\2\u00ff\u00fe\3\2\2\2\u0100\u0106") - buf.write("\3\2\2\2\u0101\u0105\7a\2\2\u0102\u0105\5W,\2\u0103\u0105") - buf.write("\5Y-\2\u0104\u0101\3\2\2\2\u0104\u0102\3\2\2\2\u0104\u0103") - buf.write("\3\2\2\2\u0105\u0108\3\2\2\2\u0106\u0104\3\2\2\2\u0106") - buf.write("\u0107\3\2\2\2\u0107V\3\2\2\2\u0108\u0106\3\2\2\2\u0109") - buf.write("\u010a\t\5\2\2\u010aX\3\2\2\2\u010b\u010c\t\6\2\2\u010c") - buf.write("Z\3\2\2\2\16\2\u0098\u00a2\u00af\u00e5\u00ea\u00ed\u00f5") - buf.write("\u00f9\u00ff\u0104\u0106\3\b\2\2") + buf.write("\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t,\4-\t-\4.") + buf.write("\t.\4/\t/\4\60\t\60\4\61\t\61\4\62\t\62\4\63\t\63\3\2") + buf.write("\3\2\3\3\3\3\3\4\3\4\3\5\3\5\3\6\3\6\3\7\3\7\3\b\3\b\3") + buf.write("\t\3\t\3\n\3\n\3\n\3\13\3\13\3\13\3\13\3\13\3\f\3\f\3") + buf.write("\f\3\f\3\r\3\r\3\16\3\16\3\17\3\17\3\17\3\20\3\20\3\20") + buf.write("\3\21\3\21\3\21\3\22\3\22\3\22\3\22\3\23\3\23\3\24\3\24") + buf.write("\3\24\3\24\3\24\3\24\3\24\3\25\3\25\3\26\3\26\3\26\3\26") + buf.write("\3\26\3\27\3\27\3\27\3\27\3\27\3\27\3\27\3\30\3\30\3\30") + buf.write("\3\30\3\30\7\30\u00b1\n\30\f\30\16\30\u00b4\13\30\3\30") + buf.write("\3\30\3\30\3\30\3\30\3\31\6\31\u00bc\n\31\r\31\16\31\u00bd") + buf.write("\3\31\3\31\3\32\3\32\3\32\3\32\7\32\u00c6\n\32\f\32\16") + buf.write("\32\u00c9\13\32\3\32\3\32\3\32\3\32\3\33\3\33\3\33\3\34") + buf.write("\3\34\3\34\7\34\u00d5\n\34\f\34\16\34\u00d8\13\34\3\34") + buf.write("\3\34\3\35\3\35\3\36\3\36\3\37\3\37\3 \3 \3!\3!\3\"\3") + buf.write("\"\3#\3#\3#\3$\3$\3$\3%\3%\3%\3&\3&\3&\3\'\3\'\3\'\3\'") + buf.write("\3\'\3\'\3\'\3\'\3\'\5\'\u00fd\n\'\3(\3(\5(\u0101\n(\3") + buf.write("(\3(\3(\7(\u0106\n(\f(\16(\u0109\13(\3(\3(\7(\u010d\n") + buf.write("(\f(\16(\u0110\13(\3)\3)\3)\3*\3*\3*\3+\3+\3+\3,\3,\3") + buf.write(",\3,\3,\3,\3-\3-\3-\5-\u0124\n-\3-\5-\u0127\n-\3.\3.\3") + buf.write(".\3/\6/\u012d\n/\r/\16/\u012e\3\60\3\60\5\60\u0133\n\60") + buf.write("\3\60\3\60\3\61\3\61\3\62\3\62\3\63\3\63\3\63\3\63\3\63") + buf.write("\3\63\3\63\3\63\3\63\3\63\3\63\7\63\u0146\n\63\f\63\16") + buf.write("\63\u0149\13\63\5\u00b2\u00c7\u00d6\2\64\3\3\5\4\7\5\t") + buf.write("\6\13\7\r\b\17\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20") + buf.write("\37\21!\22#\23%\24\'\25)\26+\27-\30/\31\61\32\63\33\65") + buf.write("\2\67\349\35;\36=\37? A!C\"E#G$I%K&M\'O(Q)S*U+W,Y\2[-") + buf.write("]._\2a\2c\2e/\3\2\b\5\2\13\f\17\17\"\"\4\2\f\f\17\17\4") + buf.write("\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2\u0155\2\3\3\2\2\2\2") + buf.write("\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2\2\13\3\2\2\2\2\r\3") + buf.write("\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2\23\3\2\2\2\2\25\3\2") + buf.write("\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33\3\2\2\2\2\35\3\2\2") + buf.write("\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2\2\2%\3\2\2\2\2\'\3") + buf.write("\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2\2\2\2/\3\2\2\2\2\61") + buf.write("\3\2\2\2\2\63\3\2\2\2\2\67\3\2\2\2\29\3\2\2\2\2;\3\2\2") + buf.write("\2\2=\3\2\2\2\2?\3\2\2\2\2A\3\2\2\2\2C\3\2\2\2\2E\3\2") + buf.write("\2\2\2G\3\2\2\2\2I\3\2\2\2\2K\3\2\2\2\2M\3\2\2\2\2O\3") + buf.write("\2\2\2\2Q\3\2\2\2\2S\3\2\2\2\2U\3\2\2\2\2W\3\2\2\2\2[") + buf.write("\3\2\2\2\2]\3\2\2\2\2e\3\2\2\2\3g\3\2\2\2\5i\3\2\2\2\7") + buf.write("k\3\2\2\2\tm\3\2\2\2\13o\3\2\2\2\rq\3\2\2\2\17s\3\2\2") + buf.write("\2\21u\3\2\2\2\23w\3\2\2\2\25z\3\2\2\2\27\177\3\2\2\2") + buf.write("\31\u0083\3\2\2\2\33\u0085\3\2\2\2\35\u0087\3\2\2\2\37") + buf.write("\u008a\3\2\2\2!\u008d\3\2\2\2#\u0090\3\2\2\2%\u0094\3") + buf.write("\2\2\2\'\u0096\3\2\2\2)\u009d\3\2\2\2+\u009f\3\2\2\2-") + buf.write("\u00a4\3\2\2\2/\u00ab\3\2\2\2\61\u00bb\3\2\2\2\63\u00c1") + buf.write("\3\2\2\2\65\u00ce\3\2\2\2\67\u00d1\3\2\2\29\u00db\3\2") + buf.write("\2\2;\u00dd\3\2\2\2=\u00df\3\2\2\2?\u00e1\3\2\2\2A\u00e3") + buf.write("\3\2\2\2C\u00e5\3\2\2\2E\u00e7\3\2\2\2G\u00ea\3\2\2\2") + buf.write("I\u00ed\3\2\2\2K\u00f0\3\2\2\2M\u00fc\3\2\2\2O\u0100\3") + buf.write("\2\2\2Q\u0111\3\2\2\2S\u0114\3\2\2\2U\u0117\3\2\2\2W\u011a") + buf.write("\3\2\2\2Y\u0120\3\2\2\2[\u0128\3\2\2\2]\u012c\3\2\2\2") + buf.write("_\u0130\3\2\2\2a\u0136\3\2\2\2c\u0138\3\2\2\2e\u013a\3") + buf.write("\2\2\2gh\7.\2\2h\4\3\2\2\2ij\7*\2\2j\6\3\2\2\2kl\7+\2") + buf.write("\2l\b\3\2\2\2mn\7}\2\2n\n\3\2\2\2op\7\177\2\2p\f\3\2\2") + buf.write("\2qr\7\60\2\2r\16\3\2\2\2st\7]\2\2t\20\3\2\2\2uv\7_\2") + buf.write("\2v\22\3\2\2\2wx\7k\2\2xy\7h\2\2y\24\3\2\2\2z{\7g\2\2") + buf.write("{|\7n\2\2|}\7u\2\2}~\7g\2\2~\26\3\2\2\2\177\u0080\7n\2") + buf.write("\2\u0080\u0081\7g\2\2\u0081\u0082\7v\2\2\u0082\30\3\2") + buf.write("\2\2\u0083\u0084\7?\2\2\u0084\32\3\2\2\2\u0085\u0086\7") + buf.write("=\2\2\u0086\34\3\2\2\2\u0087\u0088\7=\2\2\u0088\u0089") + buf.write("\7=\2\2\u0089\36\3\2\2\2\u008a\u008b\7h\2\2\u008b\u008c") + buf.write("\7p\2\2\u008c \3\2\2\2\u008d\u008e\7/\2\2\u008e\u008f") + buf.write("\7@\2\2\u008f\"\3\2\2\2\u0090\u0091\7f\2\2\u0091\u0092") + buf.write("\7g\2\2\u0092\u0093\7h\2\2\u0093$\3\2\2\2\u0094\u0095") + buf.write("\7<\2\2\u0095&\3\2\2\2\u0096\u0097\7V\2\2\u0097\u0098") + buf.write("\7g\2\2\u0098\u0099\7p\2\2\u0099\u009a\7u\2\2\u009a\u009b") + buf.write("\7q\2\2\u009b\u009c\7t\2\2\u009c(\3\2\2\2\u009d\u009e") + buf.write("\7a\2\2\u009e*\3\2\2\2\u009f\u00a0\7o\2\2\u00a0\u00a1") + buf.write("\7g\2\2\u00a1\u00a2\7v\2\2\u00a2\u00a3\7c\2\2\u00a3,\3") + buf.write("\2\2\2\u00a4\u00a5\7x\2\2\u00a5\u00a6\7\62\2\2\u00a6\u00a7") + buf.write("\7\60\2\2\u00a7\u00a8\7\62\2\2\u00a8\u00a9\7\60\2\2\u00a9") + buf.write("\u00aa\7\65\2\2\u00aa.\3\2\2\2\u00ab\u00ac\7\61\2\2\u00ac") + buf.write("\u00ad\7,\2\2\u00ad\u00b2\3\2\2\2\u00ae\u00b1\5/\30\2") + buf.write("\u00af\u00b1\13\2\2\2\u00b0\u00ae\3\2\2\2\u00b0\u00af") + buf.write("\3\2\2\2\u00b1\u00b4\3\2\2\2\u00b2\u00b3\3\2\2\2\u00b2") + buf.write("\u00b0\3\2\2\2\u00b3\u00b5\3\2\2\2\u00b4\u00b2\3\2\2\2") + buf.write("\u00b5\u00b6\7,\2\2\u00b6\u00b7\7\61\2\2\u00b7\u00b8\3") + buf.write("\2\2\2\u00b8\u00b9\b\30\2\2\u00b9\60\3\2\2\2\u00ba\u00bc") + buf.write("\t\2\2\2\u00bb\u00ba\3\2\2\2\u00bc\u00bd\3\2\2\2\u00bd") + buf.write("\u00bb\3\2\2\2\u00bd\u00be\3\2\2\2\u00be\u00bf\3\2\2\2") + buf.write("\u00bf\u00c0\b\31\2\2\u00c0\62\3\2\2\2\u00c1\u00c2\7\61") + buf.write("\2\2\u00c2\u00c3\7\61\2\2\u00c3\u00c7\3\2\2\2\u00c4\u00c6") + buf.write("\13\2\2\2\u00c5\u00c4\3\2\2\2\u00c6\u00c9\3\2\2\2\u00c7") + buf.write("\u00c8\3\2\2\2\u00c7\u00c5\3\2\2\2\u00c8\u00ca\3\2\2\2") + buf.write("\u00c9\u00c7\3\2\2\2\u00ca\u00cb\7\f\2\2\u00cb\u00cc\3") + buf.write("\2\2\2\u00cc\u00cd\b\32\2\2\u00cd\64\3\2\2\2\u00ce\u00cf") + buf.write("\7^\2\2\u00cf\u00d0\7$\2\2\u00d0\66\3\2\2\2\u00d1\u00d6") + buf.write("\7$\2\2\u00d2\u00d5\5\65\33\2\u00d3\u00d5\n\3\2\2\u00d4") + buf.write("\u00d2\3\2\2\2\u00d4\u00d3\3\2\2\2\u00d5\u00d8\3\2\2\2") + buf.write("\u00d6\u00d7\3\2\2\2\u00d6\u00d4\3\2\2\2\u00d7\u00d9\3") + buf.write("\2\2\2\u00d8\u00d6\3\2\2\2\u00d9\u00da\7$\2\2\u00da8\3") + buf.write("\2\2\2\u00db\u00dc\7,\2\2\u00dc:\3\2\2\2\u00dd\u00de\7") + buf.write("\61\2\2\u00de<\3\2\2\2\u00df\u00e0\7-\2\2\u00e0>\3\2\2") + buf.write("\2\u00e1\u00e2\7/\2\2\u00e2@\3\2\2\2\u00e3\u00e4\7>\2") + buf.write("\2\u00e4B\3\2\2\2\u00e5\u00e6\7@\2\2\u00e6D\3\2\2\2\u00e7") + buf.write("\u00e8\7>\2\2\u00e8\u00e9\7?\2\2\u00e9F\3\2\2\2\u00ea") + buf.write("\u00eb\7@\2\2\u00eb\u00ec\7?\2\2\u00ecH\3\2\2\2\u00ed") + buf.write("\u00ee\7?\2\2\u00ee\u00ef\7?\2\2\u00efJ\3\2\2\2\u00f0") + buf.write("\u00f1\7#\2\2\u00f1\u00f2\7?\2\2\u00f2L\3\2\2\2\u00f3") + buf.write("\u00f4\7V\2\2\u00f4\u00f5\7t\2\2\u00f5\u00f6\7w\2\2\u00f6") + buf.write("\u00fd\7g\2\2\u00f7\u00f8\7H\2\2\u00f8\u00f9\7c\2\2\u00f9") + buf.write("\u00fa\7n\2\2\u00fa\u00fb\7u\2\2\u00fb\u00fd\7g\2\2\u00fc") + buf.write("\u00f3\3\2\2\2\u00fc\u00f7\3\2\2\2\u00fdN\3\2\2\2\u00fe") + buf.write("\u0101\7a\2\2\u00ff\u0101\5a\61\2\u0100\u00fe\3\2\2\2") + buf.write("\u0100\u00ff\3\2\2\2\u0101\u0107\3\2\2\2\u0102\u0106\7") + buf.write("a\2\2\u0103\u0106\5a\61\2\u0104\u0106\5c\62\2\u0105\u0102") + buf.write("\3\2\2\2\u0105\u0103\3\2\2\2\u0105\u0104\3\2\2\2\u0106") + buf.write("\u0109\3\2\2\2\u0107\u0105\3\2\2\2\u0107\u0108\3\2\2\2") + buf.write("\u0108\u010e\3\2\2\2\u0109\u0107\3\2\2\2\u010a\u010b\7") + buf.write("\60\2\2\u010b\u010d\5O(\2\u010c\u010a\3\2\2\2\u010d\u0110") + buf.write("\3\2\2\2\u010e\u010c\3\2\2\2\u010e\u010f\3\2\2\2\u010f") + buf.write("P\3\2\2\2\u0110\u010e\3\2\2\2\u0111\u0112\7B\2\2\u0112") + buf.write("\u0113\5O(\2\u0113R\3\2\2\2\u0114\u0115\7\'\2\2\u0115") + buf.write("\u0116\5O(\2\u0116T\3\2\2\2\u0117\u0118\7\'\2\2\u0118") + buf.write("\u0119\5]/\2\u0119V\3\2\2\2\u011a\u011b\7k\2\2\u011b\u011c") + buf.write("\7p\2\2\u011c\u011d\7v\2\2\u011d\u011e\78\2\2\u011e\u011f") + buf.write("\7\66\2\2\u011fX\3\2\2\2\u0120\u0123\5]/\2\u0121\u0122") + buf.write("\7\60\2\2\u0122\u0124\5]/\2\u0123\u0121\3\2\2\2\u0123") + buf.write("\u0124\3\2\2\2\u0124\u0126\3\2\2\2\u0125\u0127\5_\60\2") + buf.write("\u0126\u0125\3\2\2\2\u0126\u0127\3\2\2\2\u0127Z\3\2\2") + buf.write("\2\u0128\u0129\5Y-\2\u0129\u012a\7h\2\2\u012a\\\3\2\2") + buf.write("\2\u012b\u012d\5c\62\2\u012c\u012b\3\2\2\2\u012d\u012e") + buf.write("\3\2\2\2\u012e\u012c\3\2\2\2\u012e\u012f\3\2\2\2\u012f") + buf.write("^\3\2\2\2\u0130\u0132\t\4\2\2\u0131\u0133\t\5\2\2\u0132") + buf.write("\u0131\3\2\2\2\u0132\u0133\3\2\2\2\u0133\u0134\3\2\2\2") + buf.write("\u0134\u0135\5]/\2\u0135`\3\2\2\2\u0136\u0137\t\6\2\2") + buf.write("\u0137b\3\2\2\2\u0138\u0139\t\7\2\2\u0139d\3\2\2\2\u013a") + buf.write("\u013b\7O\2\2\u013b\u013c\7G\2\2\u013c\u013d\7V\2\2\u013d") + buf.write("\u013e\7C\2\2\u013e\u013f\7F\2\2\u013f\u0140\7C\2\2\u0140") + buf.write("\u0141\7V\2\2\u0141\u0142\7C\2\2\u0142\u0143\7<\2\2\u0143") + buf.write("\u0147\3\2\2\2\u0144\u0146\13\2\2\2\u0145\u0144\3\2\2") + buf.write("\2\u0146\u0149\3\2\2\2\u0147\u0145\3\2\2\2\u0147\u0148") + buf.write("\3\2\2\2\u0148f\3\2\2\2\u0149\u0147\3\2\2\2\23\2\u00b0") + buf.write("\u00b2\u00bd\u00c7\u00d4\u00d6\u00fc\u0100\u0105\u0107") + buf.write("\u010e\u0123\u0126\u012e\u0132\u0147\3\b\2\2") return buf.getvalue() @@ -144,51 +175,59 @@ class RelayLexer(Lexer): T__15 = 16 T__16 = 17 T__17 = 18 - SEMVER = 19 - WS = 20 - LINE_COMMENT = 21 - COMMENT = 22 - MUL = 23 - DIV = 24 - ADD = 25 - SUB = 26 - LT = 27 - GT = 28 - LE = 29 - GE = 30 - EQ = 31 - NE = 32 - GLOBAL_VAR = 33 - LOCAL_VAR = 34 - GRAPH_VAR = 35 - MUT = 36 + T__18 = 19 + T__19 = 20 + T__20 = 21 + SEMVER = 22 + COMMENT = 23 + WS = 24 + LINE_COMMENT = 25 + QUOTED_STRING = 26 + MUL = 27 + DIV = 28 + ADD = 29 + SUB = 30 + LT = 31 + GT = 32 + LE = 33 + GE = 34 + EQ = 35 + NE = 36 BOOL_LIT = 37 - FLOAT = 38 - NAT = 39 - CNAME = 40 + CNAME = 38 + GLOBAL_VAR = 39 + LOCAL_VAR = 40 + GRAPH_VAR = 41 + DATATYPE = 42 + FLOAT = 43 + NAT = 44 + METADATA = 45 channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] modeNames = [ "DEFAULT_MODE" ] literalNames = [ "", - "'('", "')'", "','", "'['", "']'", "'if'", "'else'", "'let'", - "'='", "';'", "'{'", "'}'", "'fn'", "'->'", "'def'", "':'", - "'Tensor'", "'_'", "'v0.0.3'", "'*'", "'/'", "'+'", "'-'", "'<'", - "'>'", "'<='", "'>='", "'=='", "'!='", "'mut'" ] + "','", "'('", "')'", "'{'", "'}'", "'.'", "'['", "']'", "'if'", + "'else'", "'let'", "'='", "';'", "';;'", "'fn'", "'->'", "'def'", + "':'", "'Tensor'", "'_'", "'meta'", "'v0.0.3'", "'*'", "'/'", + "'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", "'!='", + "'int64'" ] symbolicNames = [ "", - "SEMVER", "WS", "LINE_COMMENT", "COMMENT", "MUL", "DIV", "ADD", - "SUB", "LT", "GT", "LE", "GE", "EQ", "NE", "GLOBAL_VAR", "LOCAL_VAR", - "GRAPH_VAR", "MUT", "BOOL_LIT", "FLOAT", "NAT", "CNAME" ] + "SEMVER", "COMMENT", "WS", "LINE_COMMENT", "QUOTED_STRING", + "MUL", "DIV", "ADD", "SUB", "LT", "GT", "LE", "GE", "EQ", "NE", + "BOOL_LIT", "CNAME", "GLOBAL_VAR", "LOCAL_VAR", "GRAPH_VAR", + "DATATYPE", "FLOAT", "NAT", "METADATA" ] ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", "T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13", - "T__14", "T__15", "T__16", "T__17", "SEMVER", "WS", "LINE_COMMENT", - "COMMENT", "MUL", "DIV", "ADD", "SUB", "LT", "GT", "LE", - "GE", "EQ", "NE", "GLOBAL_VAR", "LOCAL_VAR", "GRAPH_VAR", - "MUT", "BOOL_LIT", "PREFLOAT", "FLOAT", "NAT", "EXP", - "CNAME", "LETTER", "DIGIT" ] + "T__14", "T__15", "T__16", "T__17", "T__18", "T__19", + "T__20", "SEMVER", "COMMENT", "WS", "LINE_COMMENT", "ESCAPED_QUOTE", + "QUOTED_STRING", "MUL", "DIV", "ADD", "SUB", "LT", "GT", + "LE", "GE", "EQ", "NE", "BOOL_LIT", "CNAME", "GLOBAL_VAR", + "LOCAL_VAR", "GRAPH_VAR", "DATATYPE", "PREFLOAT", "FLOAT", + "NAT", "EXP", "LETTER", "DIGIT", "METADATA" ] grammarFileName = "Relay.g4" diff --git a/python/tvm/relay/grammar/py3/RelayParser.py b/python/tvm/relay/grammar/py3/RelayParser.py index b3c6238af8f2..923a731c3f5f 100644 --- a/python/tvm/relay/grammar/py3/RelayParser.py +++ b/python/tvm/relay/grammar/py3/RelayParser.py @@ -7,160 +7,173 @@ def serializedATN(): with StringIO() as buf: - buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\3*") - buf.write("\u014c\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7\t\7") + buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\3/") + buf.write("\u0164\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7\t\7") buf.write("\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r\4\16") buf.write("\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23\t\23") - buf.write("\3\2\3\2\3\3\3\3\7\3+\n\3\f\3\16\3.\13\3\3\3\5\3\61\n") - buf.write("\3\3\3\3\3\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4") - buf.write("\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\6\4H\n\4\r\4\16\4I\3") - buf.write("\4\3\4\3\4\3\4\3\4\3\4\7\4R\n\4\f\4\16\4U\13\4\5\4W\n") - buf.write("\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\5\4d\n") - buf.write("\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\5\4n\n\4\3\4\3\4\3") - buf.write("\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4") - buf.write("\5\4\u0080\n\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4") - buf.write("\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\7\4\u0096\n\4") - buf.write("\f\4\16\4\u0099\13\4\5\4\u009b\n\4\3\4\7\4\u009e\n\4\f") - buf.write("\4\16\4\u00a1\13\4\3\5\3\5\5\5\u00a5\n\5\3\5\3\5\3\5\3") - buf.write("\5\3\5\5\5\u00ac\n\5\3\5\3\5\3\6\3\6\3\6\5\6\u00b3\n\6") - buf.write("\3\6\3\6\3\6\3\6\3\6\5\6\u00ba\n\6\3\6\3\6\3\7\3\7\3\7") - buf.write("\3\7\3\7\3\7\5\7\u00c4\n\7\3\b\3\b\3\b\7\b\u00c9\n\b\f") - buf.write("\b\16\b\u00cc\13\b\5\b\u00ce\n\b\3\t\3\t\3\t\5\t\u00d3") - buf.write("\n\t\3\n\3\n\3\n\7\n\u00d8\n\n\f\n\16\n\u00db\13\n\5\n") - buf.write("\u00dd\n\n\3\13\3\13\3\13\3\13\3\f\3\f\3\f\3\f\3\f\3\f") - buf.write("\7\f\u00e9\n\f\f\f\16\f\u00ec\13\f\3\f\3\f\5\f\u00f0\n") - buf.write("\f\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\6\r\u00fd") - buf.write("\n\r\r\r\16\r\u00fe\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3") - buf.write("\r\3\r\3\r\3\r\5\r\u010d\n\r\3\r\3\r\3\r\3\r\7\r\u0113") - buf.write("\n\r\f\r\16\r\u0116\13\r\5\r\u0118\n\r\3\r\3\r\3\r\3\r") - buf.write("\3\r\5\r\u011f\n\r\3\16\3\16\3\16\3\16\3\16\3\16\3\16") - buf.write("\3\16\3\16\3\16\3\16\6\16\u012c\n\16\r\16\16\16\u012d") - buf.write("\3\16\3\16\5\16\u0132\n\16\3\17\3\17\3\17\3\17\3\17\5") - buf.write("\17\u0139\n\17\3\20\3\20\3\21\3\21\3\21\3\21\3\22\3\22") - buf.write("\3\22\5\22\u0144\n\22\3\23\3\23\3\23\3\23\5\23\u014a\n") - buf.write("\23\3\23\2\3\6\24\2\4\6\b\n\f\16\20\22\24\26\30\32\34") - buf.write("\36 \"$\2\6\3\2\31\32\3\2\33\34\3\2\35 \3\2!\"\2\u0175") - buf.write("\2&\3\2\2\2\4(\3\2\2\2\6\177\3\2\2\2\b\u00a2\3\2\2\2\n") - buf.write("\u00af\3\2\2\2\f\u00c3\3\2\2\2\16\u00cd\3\2\2\2\20\u00cf") - buf.write("\3\2\2\2\22\u00dc\3\2\2\2\24\u00de\3\2\2\2\26\u00ef\3") - buf.write("\2\2\2\30\u011e\3\2\2\2\32\u0131\3\2\2\2\34\u0138\3\2") - buf.write("\2\2\36\u013a\3\2\2\2 \u013c\3\2\2\2\"\u0143\3\2\2\2$") - buf.write("\u0149\3\2\2\2&\'\7*\2\2\'\3\3\2\2\2(\60\7\25\2\2)+\5") - buf.write("\n\6\2*)\3\2\2\2+.\3\2\2\2,*\3\2\2\2,-\3\2\2\2-\61\3\2") - buf.write("\2\2.,\3\2\2\2/\61\5\6\4\2\60,\3\2\2\2\60/\3\2\2\2\61") - buf.write("\62\3\2\2\2\62\63\7\2\2\3\63\5\3\2\2\2\64\65\b\4\1\2\65") - buf.write("\66\7\3\2\2\66\67\5\6\4\2\678\7\4\2\28\u0080\3\2\2\29") - buf.write(":\7\34\2\2:\u0080\5\6\4\23;\u0080\5\b\5\2<=\7\3\2\2=\u0080") - buf.write("\7\4\2\2>?\7\3\2\2?@\5\6\4\2@A\7\5\2\2AB\7\4\2\2B\u0080") - buf.write("\3\2\2\2CD\7\3\2\2DG\5\6\4\2EF\7\5\2\2FH\5\6\4\2GE\3\2") - buf.write("\2\2HI\3\2\2\2IG\3\2\2\2IJ\3\2\2\2JK\3\2\2\2KL\7\4\2\2") - buf.write("L\u0080\3\2\2\2MV\7\6\2\2NS\5\6\4\2OP\7\5\2\2PR\5\6\4") - buf.write("\2QO\3\2\2\2RU\3\2\2\2SQ\3\2\2\2ST\3\2\2\2TW\3\2\2\2U") - buf.write("S\3\2\2\2VN\3\2\2\2VW\3\2\2\2WX\3\2\2\2X\u0080\7\7\2\2") - buf.write("YZ\7\b\2\2Z[\7\3\2\2[\\\5\6\4\2\\]\7\4\2\2]^\5 \21\2^") - buf.write("_\7\t\2\2_`\5 \21\2`\u0080\3\2\2\2ac\7\n\2\2bd\7&\2\2") - buf.write("cb\3\2\2\2cd\3\2\2\2de\3\2\2\2ef\5\20\t\2fg\7\13\2\2g") - buf.write("h\5\6\4\2hi\7\f\2\2ij\5\6\4\bj\u0080\3\2\2\2km\7\n\2\2") - buf.write("ln\7&\2\2ml\3\2\2\2mn\3\2\2\2no\3\2\2\2op\5\20\t\2pq\7") - buf.write("\13\2\2qr\7\r\2\2rs\5\6\4\2st\7\16\2\2tu\7\f\2\2uv\5\6") - buf.write("\4\7v\u0080\3\2\2\2wx\5$\23\2xy\7\13\2\2yz\5\6\4\2z{\7") - buf.write("\f\2\2{|\5\6\4\5|\u0080\3\2\2\2}\u0080\5$\23\2~\u0080") - buf.write("\5\"\22\2\177\64\3\2\2\2\1779\3\2\2\2\177;\3\2\2\2\177") - buf.write("<\3\2\2\2\177>\3\2\2\2\177C\3\2\2\2\177M\3\2\2\2\177Y") - buf.write("\3\2\2\2\177a\3\2\2\2\177k\3\2\2\2\177w\3\2\2\2\177}\3") - buf.write("\2\2\2\177~\3\2\2\2\u0080\u009f\3\2\2\2\u0081\u0082\f") - buf.write("\22\2\2\u0082\u0083\t\2\2\2\u0083\u009e\5\6\4\23\u0084") - buf.write("\u0085\f\21\2\2\u0085\u0086\t\3\2\2\u0086\u009e\5\6\4") - buf.write("\22\u0087\u0088\f\20\2\2\u0088\u0089\t\4\2\2\u0089\u009e") - buf.write("\5\6\4\21\u008a\u008b\f\17\2\2\u008b\u008c\t\5\2\2\u008c") - buf.write("\u009e\5\6\4\20\u008d\u008e\f\6\2\2\u008e\u008f\7\f\2") - buf.write("\2\u008f\u009e\5\6\4\7\u0090\u0091\f\24\2\2\u0091\u009a") - buf.write("\7\3\2\2\u0092\u0097\5\6\4\2\u0093\u0094\7\5\2\2\u0094") - buf.write("\u0096\5\6\4\2\u0095\u0093\3\2\2\2\u0096\u0099\3\2\2\2") - buf.write("\u0097\u0095\3\2\2\2\u0097\u0098\3\2\2\2\u0098\u009b\3") - buf.write("\2\2\2\u0099\u0097\3\2\2\2\u009a\u0092\3\2\2\2\u009a\u009b") - buf.write("\3\2\2\2\u009b\u009c\3\2\2\2\u009c\u009e\7\4\2\2\u009d") - buf.write("\u0081\3\2\2\2\u009d\u0084\3\2\2\2\u009d\u0087\3\2\2\2") - buf.write("\u009d\u008a\3\2\2\2\u009d\u008d\3\2\2\2\u009d\u0090\3") - buf.write("\2\2\2\u009e\u00a1\3\2\2\2\u009f\u009d\3\2\2\2\u009f\u00a0") - buf.write("\3\2\2\2\u00a0\7\3\2\2\2\u00a1\u009f\3\2\2\2\u00a2\u00a4") - buf.write("\7\17\2\2\u00a3\u00a5\5\26\f\2\u00a4\u00a3\3\2\2\2\u00a4") - buf.write("\u00a5\3\2\2\2\u00a5\u00a6\3\2\2\2\u00a6\u00a7\7\3\2\2") - buf.write("\u00a7\u00a8\5\f\7\2\u00a8\u00ab\7\4\2\2\u00a9\u00aa\7") - buf.write("\20\2\2\u00aa\u00ac\5\30\r\2\u00ab\u00a9\3\2\2\2\u00ab") - buf.write("\u00ac\3\2\2\2\u00ac\u00ad\3\2\2\2\u00ad\u00ae\5 \21\2") - buf.write("\u00ae\t\3\2\2\2\u00af\u00b0\7\21\2\2\u00b0\u00b2\5$\23") - buf.write("\2\u00b1\u00b3\5\26\f\2\u00b2\u00b1\3\2\2\2\u00b2\u00b3") - buf.write("\3\2\2\2\u00b3\u00b4\3\2\2\2\u00b4\u00b5\7\3\2\2\u00b5") - buf.write("\u00b6\5\f\7\2\u00b6\u00b9\7\4\2\2\u00b7\u00b8\7\20\2") - buf.write("\2\u00b8\u00ba\5\30\r\2\u00b9\u00b7\3\2\2\2\u00b9\u00ba") - buf.write("\3\2\2\2\u00ba\u00bb\3\2\2\2\u00bb\u00bc\5 \21\2\u00bc") - buf.write("\13\3\2\2\2\u00bd\u00c4\5\16\b\2\u00be\u00c4\5\22\n\2") - buf.write("\u00bf\u00c0\5\16\b\2\u00c0\u00c1\7\5\2\2\u00c1\u00c2") - buf.write("\5\22\n\2\u00c2\u00c4\3\2\2\2\u00c3\u00bd\3\2\2\2\u00c3") - buf.write("\u00be\3\2\2\2\u00c3\u00bf\3\2\2\2\u00c4\r\3\2\2\2\u00c5") - buf.write("\u00ca\5\20\t\2\u00c6\u00c7\7\5\2\2\u00c7\u00c9\5\20\t") - buf.write("\2\u00c8\u00c6\3\2\2\2\u00c9\u00cc\3\2\2\2\u00ca\u00c8") - buf.write("\3\2\2\2\u00ca\u00cb\3\2\2\2\u00cb\u00ce\3\2\2\2\u00cc") - buf.write("\u00ca\3\2\2\2\u00cd\u00c5\3\2\2\2\u00cd\u00ce\3\2\2\2") - buf.write("\u00ce\17\3\2\2\2\u00cf\u00d2\5$\23\2\u00d0\u00d1\7\22") - buf.write("\2\2\u00d1\u00d3\5\30\r\2\u00d2\u00d0\3\2\2\2\u00d2\u00d3") - buf.write("\3\2\2\2\u00d3\21\3\2\2\2\u00d4\u00d9\5\24\13\2\u00d5") - buf.write("\u00d6\7\5\2\2\u00d6\u00d8\5\24\13\2\u00d7\u00d5\3\2\2") - buf.write("\2\u00d8\u00db\3\2\2\2\u00d9\u00d7\3\2\2\2\u00d9\u00da") - buf.write("\3\2\2\2\u00da\u00dd\3\2\2\2\u00db\u00d9\3\2\2\2\u00dc") - buf.write("\u00d4\3\2\2\2\u00dc\u00dd\3\2\2\2\u00dd\23\3\2\2\2\u00de") - buf.write("\u00df\7*\2\2\u00df\u00e0\7\13\2\2\u00e0\u00e1\5\6\4\2") - buf.write("\u00e1\25\3\2\2\2\u00e2\u00e3\7\6\2\2\u00e3\u00f0\7\7") - buf.write("\2\2\u00e4\u00e5\7\6\2\2\u00e5\u00ea\5$\23\2\u00e6\u00e7") - buf.write("\7\5\2\2\u00e7\u00e9\5$\23\2\u00e8\u00e6\3\2\2\2\u00e9") - buf.write("\u00ec\3\2\2\2\u00ea\u00e8\3\2\2\2\u00ea\u00eb\3\2\2\2") - buf.write("\u00eb\u00ed\3\2\2\2\u00ec\u00ea\3\2\2\2\u00ed\u00ee\7") - buf.write("\7\2\2\u00ee\u00f0\3\2\2\2\u00ef\u00e2\3\2\2\2\u00ef\u00e4") - buf.write("\3\2\2\2\u00f0\27\3\2\2\2\u00f1\u00f2\7\3\2\2\u00f2\u011f") - buf.write("\7\4\2\2\u00f3\u00f4\7\3\2\2\u00f4\u00f5\5\30\r\2\u00f5") - buf.write("\u00f6\7\5\2\2\u00f6\u00f7\7\4\2\2\u00f7\u011f\3\2\2\2") - buf.write("\u00f8\u00f9\7\3\2\2\u00f9\u00fc\5\30\r\2\u00fa\u00fb") - buf.write("\7\5\2\2\u00fb\u00fd\5\30\r\2\u00fc\u00fa\3\2\2\2\u00fd") - buf.write("\u00fe\3\2\2\2\u00fe\u00fc\3\2\2\2\u00fe\u00ff\3\2\2\2") - buf.write("\u00ff\u0100\3\2\2\2\u0100\u0101\7\4\2\2\u0101\u011f\3") - buf.write("\2\2\2\u0102\u011f\5\36\20\2\u0103\u0104\7\23\2\2\u0104") - buf.write("\u0105\7\6\2\2\u0105\u0106\5\32\16\2\u0106\u0107\7\5\2") - buf.write("\2\u0107\u0108\5\30\r\2\u0108\u0109\7\7\2\2\u0109\u011f") - buf.write("\3\2\2\2\u010a\u010c\7\17\2\2\u010b\u010d\5\26\f\2\u010c") - buf.write("\u010b\3\2\2\2\u010c\u010d\3\2\2\2\u010d\u010e\3\2\2\2") - buf.write("\u010e\u0117\7\3\2\2\u010f\u0114\5\30\r\2\u0110\u0111") - buf.write("\7\5\2\2\u0111\u0113\5\30\r\2\u0112\u0110\3\2\2\2\u0113") - buf.write("\u0116\3\2\2\2\u0114\u0112\3\2\2\2\u0114\u0115\3\2\2\2") - buf.write("\u0115\u0118\3\2\2\2\u0116\u0114\3\2\2\2\u0117\u010f\3") - buf.write("\2\2\2\u0117\u0118\3\2\2\2\u0118\u0119\3\2\2\2\u0119\u011a") - buf.write("\7\4\2\2\u011a\u011b\7\20\2\2\u011b\u011f\5\30\r\2\u011c") - buf.write("\u011f\7\24\2\2\u011d\u011f\7)\2\2\u011e\u00f1\3\2\2\2") - buf.write("\u011e\u00f3\3\2\2\2\u011e\u00f8\3\2\2\2\u011e\u0102\3") - buf.write("\2\2\2\u011e\u0103\3\2\2\2\u011e\u010a\3\2\2\2\u011e\u011c") - buf.write("\3\2\2\2\u011e\u011d\3\2\2\2\u011f\31\3\2\2\2\u0120\u0121") - buf.write("\7\3\2\2\u0121\u0132\7\4\2\2\u0122\u0123\7\3\2\2\u0123") - buf.write("\u0124\5\34\17\2\u0124\u0125\7\5\2\2\u0125\u0126\7\4\2") - buf.write("\2\u0126\u0132\3\2\2\2\u0127\u0128\7\3\2\2\u0128\u012b") - buf.write("\5\34\17\2\u0129\u012a\7\5\2\2\u012a\u012c\5\34\17\2\u012b") - buf.write("\u0129\3\2\2\2\u012c\u012d\3\2\2\2\u012d\u012b\3\2\2\2") - buf.write("\u012d\u012e\3\2\2\2\u012e\u012f\3\2\2\2\u012f\u0130\7") - buf.write("\4\2\2\u0130\u0132\3\2\2\2\u0131\u0120\3\2\2\2\u0131\u0122") - buf.write("\3\2\2\2\u0131\u0127\3\2\2\2\u0132\33\3\2\2\2\u0133\u0134") - buf.write("\7\3\2\2\u0134\u0135\5\34\17\2\u0135\u0136\7\4\2\2\u0136") - buf.write("\u0139\3\2\2\2\u0137\u0139\7)\2\2\u0138\u0133\3\2\2\2") - buf.write("\u0138\u0137\3\2\2\2\u0139\35\3\2\2\2\u013a\u013b\7*\2") - buf.write("\2\u013b\37\3\2\2\2\u013c\u013d\7\r\2\2\u013d\u013e\5") - buf.write("\6\4\2\u013e\u013f\7\16\2\2\u013f!\3\2\2\2\u0140\u0144") - buf.write("\7(\2\2\u0141\u0144\7)\2\2\u0142\u0144\7\'\2\2\u0143\u0140") - buf.write("\3\2\2\2\u0143\u0141\3\2\2\2\u0143\u0142\3\2\2\2\u0144") - buf.write("#\3\2\2\2\u0145\u014a\5\2\2\2\u0146\u014a\7#\2\2\u0147") - buf.write("\u014a\7$\2\2\u0148\u014a\7%\2\2\u0149\u0145\3\2\2\2\u0149") - buf.write("\u0146\3\2\2\2\u0149\u0147\3\2\2\2\u0149\u0148\3\2\2\2") - buf.write("\u014a%\3\2\2\2$,\60ISVcm\177\u0097\u009a\u009d\u009f") - buf.write("\u00a4\u00ab\u00b2\u00b9\u00c3\u00ca\u00cd\u00d2\u00d9") - buf.write("\u00dc\u00ea\u00ef\u00fe\u010c\u0114\u0117\u011e\u012d") - buf.write("\u0131\u0138\u0143\u0149") + buf.write("\4\24\t\24\4\25\t\25\4\26\t\26\3\2\3\2\3\3\3\3\7\3\61") + buf.write("\n\3\f\3\16\3\64\13\3\3\3\5\3\67\n\3\3\3\5\3:\n\3\3\3") + buf.write("\3\3\3\4\3\4\3\4\7\4A\n\4\f\4\16\4D\13\4\5\4F\n\4\3\5") + buf.write("\3\5\3\5\3\5\7\5L\n\5\f\5\16\5O\13\5\3\5\5\5R\n\5\3\6") + buf.write("\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3") + buf.write("\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\6\6k\n\6\r\6\16\6l") + buf.write("\3\6\3\6\3\6\3\6\3\6\3\6\7\6u\n\6\f\6\16\6x\13\6\5\6z") + buf.write("\n\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3") + buf.write("\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6") + buf.write("\5\6\u0096\n\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6") + buf.write("\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\3\6\7") + buf.write("\6\u00af\n\6\f\6\16\6\u00b2\13\6\3\7\3\7\5\7\u00b6\n\7") + buf.write("\3\7\3\7\3\7\3\7\3\7\5\7\u00bd\n\7\3\7\3\7\3\b\3\b\3\b") + buf.write("\5\b\u00c4\n\b\3\b\3\b\3\b\3\b\3\b\5\b\u00cb\n\b\3\b\3") + buf.write("\b\3\t\3\t\3\t\3\t\7\t\u00d3\n\t\f\t\16\t\u00d6\13\t\3") + buf.write("\t\5\t\u00d9\n\t\3\n\3\n\3\n\7\n\u00de\n\n\f\n\16\n\u00e1") + buf.write("\13\n\5\n\u00e3\n\n\3\13\3\13\3\13\5\13\u00e8\n\13\3\f") + buf.write("\3\f\3\f\7\f\u00ed\n\f\f\f\16\f\u00f0\13\f\3\r\3\r\3\r") + buf.write("\3\r\3\16\3\16\3\16\3\16\3\16\3\16\7\16\u00fc\n\16\f\16") + buf.write("\16\16\u00ff\13\16\3\16\3\16\5\16\u0103\n\16\3\17\3\17") + buf.write("\3\17\3\17\3\17\3\17\3\17\3\17\3\17\3\17\3\17\6\17\u0110") + buf.write("\n\17\r\17\16\17\u0111\3\17\3\17\3\17\3\17\3\17\3\17\3") + buf.write("\17\3\17\3\17\3\17\3\17\3\17\5\17\u0120\n\17\3\17\3\17") + buf.write("\3\17\3\17\7\17\u0126\n\17\f\17\16\17\u0129\13\17\5\17") + buf.write("\u012b\n\17\3\17\3\17\3\17\3\17\3\17\5\17\u0132\n\17\3") + buf.write("\20\3\20\3\20\3\20\6\20\u0138\n\20\r\20\16\20\u0139\3") + buf.write("\20\3\20\3\20\3\20\3\20\5\20\u0141\n\20\3\21\3\21\3\21") + buf.write("\3\21\3\21\3\21\3\21\3\21\3\22\3\22\3\22\3\22\3\22\3\22") + buf.write("\5\22\u0151\n\22\3\23\3\23\3\24\3\24\3\24\3\24\3\25\3") + buf.write("\25\3\25\5\25\u015c\n\25\3\26\3\26\3\26\3\26\5\26\u0162") + buf.write("\n\26\3\26\2\3\n\27\2\4\6\b\n\f\16\20\22\24\26\30\32\34") + buf.write("\36 \"$&(*\2\6\3\2\35\36\3\2\37 \3\2!$\3\2%&\2\u018e\2") + buf.write(",\3\2\2\2\4.\3\2\2\2\6E\3\2\2\2\bQ\3\2\2\2\n\u0095\3\2") + buf.write("\2\2\f\u00b3\3\2\2\2\16\u00c0\3\2\2\2\20\u00d8\3\2\2\2") + buf.write("\22\u00e2\3\2\2\2\24\u00e4\3\2\2\2\26\u00e9\3\2\2\2\30") + buf.write("\u00f1\3\2\2\2\32\u0102\3\2\2\2\34\u0131\3\2\2\2\36\u0140") + buf.write("\3\2\2\2 \u0142\3\2\2\2\"\u0150\3\2\2\2$\u0152\3\2\2\2") + buf.write("&\u0154\3\2\2\2(\u015b\3\2\2\2*\u0161\3\2\2\2,-\7(\2\2") + buf.write("-\3\3\2\2\2.\66\7\30\2\2/\61\5\16\b\2\60/\3\2\2\2\61\64") + buf.write("\3\2\2\2\62\60\3\2\2\2\62\63\3\2\2\2\63\67\3\2\2\2\64") + buf.write("\62\3\2\2\2\65\67\5\n\6\2\66\62\3\2\2\2\66\65\3\2\2\2") + buf.write("\679\3\2\2\28:\7/\2\298\3\2\2\29:\3\2\2\2:;\3\2\2\2;<") + buf.write("\7\2\2\3<\5\3\2\2\2=B\5\n\6\2>?\7\3\2\2?A\5\n\6\2@>\3") + buf.write("\2\2\2AD\3\2\2\2B@\3\2\2\2BC\3\2\2\2CF\3\2\2\2DB\3\2\2") + buf.write("\2E=\3\2\2\2EF\3\2\2\2F\7\3\2\2\2GR\5\6\4\2HI\5\n\6\2") + buf.write("IJ\7\3\2\2JL\3\2\2\2KH\3\2\2\2LO\3\2\2\2MK\3\2\2\2MN\3") + buf.write("\2\2\2NP\3\2\2\2OM\3\2\2\2PR\5\26\f\2QG\3\2\2\2QM\3\2") + buf.write("\2\2R\t\3\2\2\2ST\b\6\1\2TU\7\4\2\2UV\5\n\6\2VW\7\5\2") + buf.write("\2W\u0096\3\2\2\2XY\7\6\2\2YZ\5\n\6\2Z[\7\7\2\2[\u0096") + buf.write("\3\2\2\2\\]\7 \2\2]\u0096\5\n\6\25^\u0096\5\f\7\2_`\7") + buf.write("\4\2\2`\u0096\7\5\2\2ab\7\4\2\2bc\5\n\6\2cd\7\3\2\2de") + buf.write("\7\5\2\2e\u0096\3\2\2\2fg\7\4\2\2gj\5\n\6\2hi\7\3\2\2") + buf.write("ik\5\n\6\2jh\3\2\2\2kl\3\2\2\2lj\3\2\2\2lm\3\2\2\2mn\3") + buf.write("\2\2\2no\7\5\2\2o\u0096\3\2\2\2py\7\t\2\2qv\5\n\6\2rs") + buf.write("\7\3\2\2su\5\n\6\2tr\3\2\2\2ux\3\2\2\2vt\3\2\2\2vw\3\2") + buf.write("\2\2wz\3\2\2\2xv\3\2\2\2yq\3\2\2\2yz\3\2\2\2z{\3\2\2\2") + buf.write("{\u0096\7\n\2\2|}\7\13\2\2}~\7\4\2\2~\177\5\n\6\2\177") + buf.write("\u0080\7\5\2\2\u0080\u0081\5&\24\2\u0081\u0082\7\f\2\2") + buf.write("\u0082\u0083\5&\24\2\u0083\u0096\3\2\2\2\u0084\u0085\7") + buf.write("\r\2\2\u0085\u0086\5\24\13\2\u0086\u0087\7\16\2\2\u0087") + buf.write("\u0088\5\n\6\2\u0088\u0089\7\17\2\2\u0089\u008a\5\n\6") + buf.write("\t\u008a\u0096\3\2\2\2\u008b\u008c\7+\2\2\u008c\u008d") + buf.write("\7\16\2\2\u008d\u008e\5\n\6\2\u008e\u008f\7\17\2\2\u008f") + buf.write("\u0090\5\n\6\7\u0090\u0096\3\2\2\2\u0091\u0096\5*\26\2") + buf.write("\u0092\u0096\5(\25\2\u0093\u0096\5 \21\2\u0094\u0096\7") + buf.write("\34\2\2\u0095S\3\2\2\2\u0095X\3\2\2\2\u0095\\\3\2\2\2") + buf.write("\u0095^\3\2\2\2\u0095_\3\2\2\2\u0095a\3\2\2\2\u0095f\3") + buf.write("\2\2\2\u0095p\3\2\2\2\u0095|\3\2\2\2\u0095\u0084\3\2\2") + buf.write("\2\u0095\u008b\3\2\2\2\u0095\u0091\3\2\2\2\u0095\u0092") + buf.write("\3\2\2\2\u0095\u0093\3\2\2\2\u0095\u0094\3\2\2\2\u0096") + buf.write("\u00b0\3\2\2\2\u0097\u0098\f\24\2\2\u0098\u0099\t\2\2") + buf.write("\2\u0099\u00af\5\n\6\25\u009a\u009b\f\23\2\2\u009b\u009c") + buf.write("\t\3\2\2\u009c\u00af\5\n\6\24\u009d\u009e\f\22\2\2\u009e") + buf.write("\u009f\t\4\2\2\u009f\u00af\5\n\6\23\u00a0\u00a1\f\21\2") + buf.write("\2\u00a1\u00a2\t\5\2\2\u00a2\u00af\5\n\6\22\u00a3\u00a4") + buf.write("\f\b\2\2\u00a4\u00a5\7\20\2\2\u00a5\u00af\5\n\6\t\u00a6") + buf.write("\u00a7\f\26\2\2\u00a7\u00a8\7\4\2\2\u00a8\u00a9\5\b\5") + buf.write("\2\u00a9\u00aa\7\5\2\2\u00aa\u00af\3\2\2\2\u00ab\u00ac") + buf.write("\f\f\2\2\u00ac\u00ad\7\b\2\2\u00ad\u00af\7.\2\2\u00ae") + buf.write("\u0097\3\2\2\2\u00ae\u009a\3\2\2\2\u00ae\u009d\3\2\2\2") + buf.write("\u00ae\u00a0\3\2\2\2\u00ae\u00a3\3\2\2\2\u00ae\u00a6\3") + buf.write("\2\2\2\u00ae\u00ab\3\2\2\2\u00af\u00b2\3\2\2\2\u00b0\u00ae") + buf.write("\3\2\2\2\u00b0\u00b1\3\2\2\2\u00b1\13\3\2\2\2\u00b2\u00b0") + buf.write("\3\2\2\2\u00b3\u00b5\7\21\2\2\u00b4\u00b6\5\32\16\2\u00b5") + buf.write("\u00b4\3\2\2\2\u00b5\u00b6\3\2\2\2\u00b6\u00b7\3\2\2\2") + buf.write("\u00b7\u00b8\7\4\2\2\u00b8\u00b9\5\20\t\2\u00b9\u00bc") + buf.write("\7\5\2\2\u00ba\u00bb\7\22\2\2\u00bb\u00bd\5\34\17\2\u00bc") + buf.write("\u00ba\3\2\2\2\u00bc\u00bd\3\2\2\2\u00bd\u00be\3\2\2\2") + buf.write("\u00be\u00bf\5&\24\2\u00bf\r\3\2\2\2\u00c0\u00c1\7\23") + buf.write("\2\2\u00c1\u00c3\5*\26\2\u00c2\u00c4\5\32\16\2\u00c3\u00c2") + buf.write("\3\2\2\2\u00c3\u00c4\3\2\2\2\u00c4\u00c5\3\2\2\2\u00c5") + buf.write("\u00c6\7\4\2\2\u00c6\u00c7\5\20\t\2\u00c7\u00ca\7\5\2") + buf.write("\2\u00c8\u00c9\7\22\2\2\u00c9\u00cb\5\34\17\2\u00ca\u00c8") + buf.write("\3\2\2\2\u00ca\u00cb\3\2\2\2\u00cb\u00cc\3\2\2\2\u00cc") + buf.write("\u00cd\5&\24\2\u00cd\17\3\2\2\2\u00ce\u00d9\5\22\n\2\u00cf") + buf.write("\u00d0\5\24\13\2\u00d0\u00d1\7\3\2\2\u00d1\u00d3\3\2\2") + buf.write("\2\u00d2\u00cf\3\2\2\2\u00d3\u00d6\3\2\2\2\u00d4\u00d2") + buf.write("\3\2\2\2\u00d4\u00d5\3\2\2\2\u00d5\u00d7\3\2\2\2\u00d6") + buf.write("\u00d4\3\2\2\2\u00d7\u00d9\5\26\f\2\u00d8\u00ce\3\2\2") + buf.write("\2\u00d8\u00d4\3\2\2\2\u00d9\21\3\2\2\2\u00da\u00df\5") + buf.write("\24\13\2\u00db\u00dc\7\3\2\2\u00dc\u00de\5\24\13\2\u00dd") + buf.write("\u00db\3\2\2\2\u00de\u00e1\3\2\2\2\u00df\u00dd\3\2\2\2") + buf.write("\u00df\u00e0\3\2\2\2\u00e0\u00e3\3\2\2\2\u00e1\u00df\3") + buf.write("\2\2\2\u00e2\u00da\3\2\2\2\u00e2\u00e3\3\2\2\2\u00e3\23") + buf.write("\3\2\2\2\u00e4\u00e7\7*\2\2\u00e5\u00e6\7\24\2\2\u00e6") + buf.write("\u00e8\5\34\17\2\u00e7\u00e5\3\2\2\2\u00e7\u00e8\3\2\2") + buf.write("\2\u00e8\25\3\2\2\2\u00e9\u00ee\5\30\r\2\u00ea\u00eb\7") + buf.write("\3\2\2\u00eb\u00ed\5\30\r\2\u00ec\u00ea\3\2\2\2\u00ed") + buf.write("\u00f0\3\2\2\2\u00ee\u00ec\3\2\2\2\u00ee\u00ef\3\2\2\2") + buf.write("\u00ef\27\3\2\2\2\u00f0\u00ee\3\2\2\2\u00f1\u00f2\7(\2") + buf.write("\2\u00f2\u00f3\7\16\2\2\u00f3\u00f4\5\n\6\2\u00f4\31\3") + buf.write("\2\2\2\u00f5\u00f6\7\t\2\2\u00f6\u0103\7\n\2\2\u00f7\u00f8") + buf.write("\7\t\2\2\u00f8\u00fd\5*\26\2\u00f9\u00fa\7\3\2\2\u00fa") + buf.write("\u00fc\5*\26\2\u00fb\u00f9\3\2\2\2\u00fc\u00ff\3\2\2\2") + buf.write("\u00fd\u00fb\3\2\2\2\u00fd\u00fe\3\2\2\2\u00fe\u0100\3") + buf.write("\2\2\2\u00ff\u00fd\3\2\2\2\u0100\u0101\7\n\2\2\u0101\u0103") + buf.write("\3\2\2\2\u0102\u00f5\3\2\2\2\u0102\u00f7\3\2\2\2\u0103") + buf.write("\33\3\2\2\2\u0104\u0105\7\4\2\2\u0105\u0132\7\5\2\2\u0106") + buf.write("\u0107\7\4\2\2\u0107\u0108\5\34\17\2\u0108\u0109\7\3\2") + buf.write("\2\u0109\u010a\7\5\2\2\u010a\u0132\3\2\2\2\u010b\u010c") + buf.write("\7\4\2\2\u010c\u010f\5\34\17\2\u010d\u010e\7\3\2\2\u010e") + buf.write("\u0110\5\34\17\2\u010f\u010d\3\2\2\2\u0110\u0111\3\2\2") + buf.write("\2\u0111\u010f\3\2\2\2\u0111\u0112\3\2\2\2\u0112\u0113") + buf.write("\3\2\2\2\u0113\u0114\7\5\2\2\u0114\u0132\3\2\2\2\u0115") + buf.write("\u0132\5$\23\2\u0116\u0117\7\25\2\2\u0117\u0118\7\t\2") + buf.write("\2\u0118\u0119\5\36\20\2\u0119\u011a\7\3\2\2\u011a\u011b") + buf.write("\5\34\17\2\u011b\u011c\7\n\2\2\u011c\u0132\3\2\2\2\u011d") + buf.write("\u011f\7\21\2\2\u011e\u0120\5\32\16\2\u011f\u011e\3\2") + buf.write("\2\2\u011f\u0120\3\2\2\2\u0120\u0121\3\2\2\2\u0121\u012a") + buf.write("\7\4\2\2\u0122\u0127\5\34\17\2\u0123\u0124\7\3\2\2\u0124") + buf.write("\u0126\5\34\17\2\u0125\u0123\3\2\2\2\u0126\u0129\3\2\2") + buf.write("\2\u0127\u0125\3\2\2\2\u0127\u0128\3\2\2\2\u0128\u012b") + buf.write("\3\2\2\2\u0129\u0127\3\2\2\2\u012a\u0122\3\2\2\2\u012a") + buf.write("\u012b\3\2\2\2\u012b\u012c\3\2\2\2\u012c\u012d\7\5\2\2") + buf.write("\u012d\u012e\7\22\2\2\u012e\u0132\5\34\17\2\u012f\u0132") + buf.write("\7\26\2\2\u0130\u0132\7.\2\2\u0131\u0104\3\2\2\2\u0131") + buf.write("\u0106\3\2\2\2\u0131\u010b\3\2\2\2\u0131\u0115\3\2\2\2") + buf.write("\u0131\u0116\3\2\2\2\u0131\u011d\3\2\2\2\u0131\u012f\3") + buf.write("\2\2\2\u0131\u0130\3\2\2\2\u0132\35\3\2\2\2\u0133\u0134") + buf.write("\7\4\2\2\u0134\u0137\5\"\22\2\u0135\u0136\7\3\2\2\u0136") + buf.write("\u0138\5\"\22\2\u0137\u0135\3\2\2\2\u0138\u0139\3\2\2") + buf.write("\2\u0139\u0137\3\2\2\2\u0139\u013a\3\2\2\2\u013a\u013b") + buf.write("\3\2\2\2\u013b\u013c\7\5\2\2\u013c\u0141\3\2\2\2\u013d") + buf.write("\u013e\7\4\2\2\u013e\u0141\7\5\2\2\u013f\u0141\5\"\22") + buf.write("\2\u0140\u0133\3\2\2\2\u0140\u013d\3\2\2\2\u0140\u013f") + buf.write("\3\2\2\2\u0141\37\3\2\2\2\u0142\u0143\7\27\2\2\u0143\u0144") + buf.write("\7\t\2\2\u0144\u0145\7(\2\2\u0145\u0146\7\n\2\2\u0146") + buf.write("\u0147\7\t\2\2\u0147\u0148\7.\2\2\u0148\u0149\7\n\2\2") + buf.write("\u0149!\3\2\2\2\u014a\u0151\5 \21\2\u014b\u014c\7\4\2") + buf.write("\2\u014c\u014d\5\"\22\2\u014d\u014e\7\5\2\2\u014e\u0151") + buf.write("\3\2\2\2\u014f\u0151\7.\2\2\u0150\u014a\3\2\2\2\u0150") + buf.write("\u014b\3\2\2\2\u0150\u014f\3\2\2\2\u0151#\3\2\2\2\u0152") + buf.write("\u0153\7(\2\2\u0153%\3\2\2\2\u0154\u0155\7\6\2\2\u0155") + buf.write("\u0156\5\n\6\2\u0156\u0157\7\7\2\2\u0157\'\3\2\2\2\u0158") + buf.write("\u015c\7-\2\2\u0159\u015c\7.\2\2\u015a\u015c\7\'\2\2\u015b") + buf.write("\u0158\3\2\2\2\u015b\u0159\3\2\2\2\u015b\u015a\3\2\2\2") + buf.write("\u015c)\3\2\2\2\u015d\u0162\5\2\2\2\u015e\u0162\7)\2\2") + buf.write("\u015f\u0162\7*\2\2\u0160\u0162\7+\2\2\u0161\u015d\3\2") + buf.write("\2\2\u0161\u015e\3\2\2\2\u0161\u015f\3\2\2\2\u0161\u0160") + buf.write("\3\2\2\2\u0162+\3\2\2\2%\62\669BEMQlvy\u0095\u00ae\u00b0") + buf.write("\u00b5\u00bc\u00c3\u00ca\u00d4\u00d8\u00df\u00e2\u00e7") + buf.write("\u00ee\u00fd\u0102\u0111\u011f\u0127\u012a\u0131\u0139") + buf.write("\u0140\u0150\u015b\u0161") return buf.getvalue() @@ -174,46 +187,52 @@ class RelayParser ( Parser ): sharedContextCache = PredictionContextCache() - literalNames = [ "", "'('", "')'", "','", "'['", "']'", "'if'", - "'else'", "'let'", "'='", "';'", "'{'", "'}'", "'fn'", - "'->'", "'def'", "':'", "'Tensor'", "'_'", "'v0.0.3'", - "", "", "", "'*'", "'/'", - "'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", - "'!='", "", "", "", "'mut'" ] + literalNames = [ "", "','", "'('", "')'", "'{'", "'}'", "'.'", + "'['", "']'", "'if'", "'else'", "'let'", "'='", "';'", + "';;'", "'fn'", "'->'", "'def'", "':'", "'Tensor'", + "'_'", "'meta'", "'v0.0.3'", "", "", + "", "", "'*'", "'/'", "'+'", "'-'", + "'<'", "'>'", "'<='", "'>='", "'=='", "'!='", "", + "", "", "", "", + "'int64'" ] symbolicNames = [ "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", - "", "", "", "SEMVER", "WS", - "LINE_COMMENT", "COMMENT", "MUL", "DIV", "ADD", "SUB", - "LT", "GT", "LE", "GE", "EQ", "NE", "GLOBAL_VAR", - "LOCAL_VAR", "GRAPH_VAR", "MUT", "BOOL_LIT", "FLOAT", - "NAT", "CNAME" ] + "", "", "", "", + "", "", "SEMVER", "COMMENT", "WS", + "LINE_COMMENT", "QUOTED_STRING", "MUL", "DIV", "ADD", + "SUB", "LT", "GT", "LE", "GE", "EQ", "NE", "BOOL_LIT", + "CNAME", "GLOBAL_VAR", "LOCAL_VAR", "GRAPH_VAR", "DATATYPE", + "FLOAT", "NAT", "METADATA" ] RULE_opIdent = 0 RULE_prog = 1 - RULE_expr = 2 - RULE_func = 3 - RULE_defn = 4 - RULE_argList = 5 - RULE_varList = 6 - RULE_var = 7 - RULE_attrList = 8 - RULE_attr = 9 - RULE_typeParamSeq = 10 - RULE_type_ = 11 - RULE_shapeSeq = 12 - RULE_shape = 13 - RULE_typeIdent = 14 - RULE_body = 15 - RULE_scalar = 16 - RULE_ident = 17 - - ruleNames = [ "opIdent", "prog", "expr", "func", "defn", "argList", - "varList", "var", "attrList", "attr", "typeParamSeq", - "type_", "shapeSeq", "shape", "typeIdent", "body", "scalar", - "ident" ] + RULE_exprList = 2 + RULE_callList = 3 + RULE_expr = 4 + RULE_func = 5 + RULE_defn = 6 + RULE_argList = 7 + RULE_varList = 8 + RULE_var = 9 + RULE_attrSeq = 10 + RULE_attr = 11 + RULE_typeParamList = 12 + RULE_type_ = 13 + RULE_shapeList = 14 + RULE_meta = 15 + RULE_shape = 16 + RULE_typeIdent = 17 + RULE_body = 18 + RULE_scalar = 19 + RULE_ident = 20 + + ruleNames = [ "opIdent", "prog", "exprList", "callList", "expr", "func", + "defn", "argList", "varList", "var", "attrSeq", "attr", + "typeParamList", "type_", "shapeList", "meta", "shape", + "typeIdent", "body", "scalar", "ident" ] EOF = Token.EOF T__0=1 @@ -234,28 +253,33 @@ class RelayParser ( Parser ): T__15=16 T__16=17 T__17=18 - SEMVER=19 - WS=20 - LINE_COMMENT=21 - COMMENT=22 - MUL=23 - DIV=24 - ADD=25 - SUB=26 - LT=27 - GT=28 - LE=29 - GE=30 - EQ=31 - NE=32 - GLOBAL_VAR=33 - LOCAL_VAR=34 - GRAPH_VAR=35 - MUT=36 + T__18=19 + T__19=20 + T__20=21 + SEMVER=22 + COMMENT=23 + WS=24 + LINE_COMMENT=25 + QUOTED_STRING=26 + MUL=27 + DIV=28 + ADD=29 + SUB=30 + LT=31 + GT=32 + LE=33 + GE=34 + EQ=35 + NE=36 BOOL_LIT=37 - FLOAT=38 - NAT=39 - CNAME=40 + CNAME=38 + GLOBAL_VAR=39 + LOCAL_VAR=40 + GRAPH_VAR=41 + DATATYPE=42 + FLOAT=43 + NAT=44 + METADATA=45 def __init__(self, input:TokenStream, output:TextIO = sys.stdout): super().__init__(input, output) @@ -292,7 +316,7 @@ def opIdent(self): self.enterRule(localctx, 0, self.RULE_opIdent) try: self.enterOuterAlt(localctx, 1) - self.state = 36 + self.state = 42 self.match(RelayParser.CNAME) except RecognitionException as re: localctx.exception = re @@ -318,6 +342,9 @@ def expr(self): return self.getTypedRuleContext(RelayParser.ExprContext,0) + def METADATA(self): + return self.getToken(RelayParser.METADATA, 0) + def defn(self, i:int=None): if i is None: return self.getTypedRuleContexts(RelayParser.DefnContext) @@ -344,31 +371,39 @@ def prog(self): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 38 + self.state = 44 self.match(RelayParser.SEMVER) - self.state = 46 + self.state = 52 self._errHandler.sync(self) token = self._input.LA(1) - if token in [RelayParser.EOF, RelayParser.T__14]: - self.state = 42 + if token in [RelayParser.EOF, RelayParser.T__16, RelayParser.METADATA]: + self.state = 48 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==RelayParser.T__14: - self.state = 39 + while _la==RelayParser.T__16: + self.state = 45 self.defn() - self.state = 44 + self.state = 50 self._errHandler.sync(self) _la = self._input.LA(1) pass - elif token in [RelayParser.T__0, RelayParser.T__3, RelayParser.T__5, RelayParser.T__7, RelayParser.T__12, RelayParser.SUB, RelayParser.GLOBAL_VAR, RelayParser.LOCAL_VAR, RelayParser.GRAPH_VAR, RelayParser.BOOL_LIT, RelayParser.FLOAT, RelayParser.NAT, RelayParser.CNAME]: - self.state = 45 + elif token in [RelayParser.T__1, RelayParser.T__3, RelayParser.T__6, RelayParser.T__8, RelayParser.T__10, RelayParser.T__14, RelayParser.T__20, RelayParser.QUOTED_STRING, RelayParser.SUB, RelayParser.BOOL_LIT, RelayParser.CNAME, RelayParser.GLOBAL_VAR, RelayParser.LOCAL_VAR, RelayParser.GRAPH_VAR, RelayParser.FLOAT, RelayParser.NAT]: + self.state = 51 self.expr(0) pass else: raise NoViableAltException(self) - self.state = 48 + self.state = 55 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.METADATA: + self.state = 54 + self.match(RelayParser.METADATA) + + + self.state = 57 self.match(RelayParser.EOF) except RecognitionException as re: localctx.exception = re @@ -378,6 +413,167 @@ def prog(self): self.exitRule() return localctx + class ExprListContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_exprList + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitExprList" ): + return visitor.visitExprList(self) + else: + return visitor.visitChildren(self) + + + + + def exprList(self): + + localctx = RelayParser.ExprListContext(self, self._ctx, self.state) + self.enterRule(localctx, 4, self.RULE_exprList) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 67 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__1) | (1 << RelayParser.T__3) | (1 << RelayParser.T__6) | (1 << RelayParser.T__8) | (1 << RelayParser.T__10) | (1 << RelayParser.T__14) | (1 << RelayParser.T__20) | (1 << RelayParser.QUOTED_STRING) | (1 << RelayParser.SUB) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.CNAME) | (1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT))) != 0): + self.state = 59 + self.expr(0) + self.state = 64 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__0: + self.state = 60 + self.match(RelayParser.T__0) + self.state = 61 + self.expr(0) + self.state = 66 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + class CallListContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_callList + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + + class CallWithAttrContext(CallListContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.CallListContext + super().__init__(parser) + self.copyFrom(ctx) + + def attrSeq(self): + return self.getTypedRuleContext(RelayParser.AttrSeqContext,0) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitCallWithAttr" ): + return visitor.visitCallWithAttr(self) + else: + return visitor.visitChildren(self) + + + class CallNoAttrContext(CallListContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.CallListContext + super().__init__(parser) + self.copyFrom(ctx) + + def exprList(self): + return self.getTypedRuleContext(RelayParser.ExprListContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitCallNoAttr" ): + return visitor.visitCallNoAttr(self) + else: + return visitor.visitChildren(self) + + + + def callList(self): + + localctx = RelayParser.CallListContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_callList) + try: + self.state = 79 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,6,self._ctx) + if la_ == 1: + localctx = RelayParser.CallNoAttrContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 69 + self.exprList() + pass + + elif la_ == 2: + localctx = RelayParser.CallWithAttrContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 75 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,5,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 70 + self.expr(0) + self.state = 71 + self.match(RelayParser.T__0) + self.state = 77 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,5,self._ctx) + + self.state = 78 + self.attrSeq() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + class ExprContext(ParserRuleContext): def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): @@ -393,6 +589,82 @@ def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) + class FuncExprContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def func(self): + return self.getTypedRuleContext(RelayParser.FuncContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitFuncExpr" ): + return visitor.visitFuncExpr(self) + else: + return visitor.visitChildren(self) + + + class MetaExprContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def meta(self): + return self.getTypedRuleContext(RelayParser.MetaContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitMetaExpr" ): + return visitor.visitMetaExpr(self) + else: + return visitor.visitChildren(self) + + + class TensorContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTensor" ): + return visitor.visitTensor(self) + else: + return visitor.visitChildren(self) + + + class GraphContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def GRAPH_VAR(self): + return self.getToken(RelayParser.GRAPH_VAR, 0) + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitGraph" ): + return visitor.visitGraph(self) + else: + return visitor.visitChildren(self) + + class IdentExprContext(ExprContext): def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext @@ -410,17 +682,33 @@ def accept(self, visitor:ParseTreeVisitor): return visitor.visitChildren(self) - class CallContext(ExprContext): + class StringExprContext(ExprContext): def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext super().__init__(parser) self.copyFrom(ctx) - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) + def QUOTED_STRING(self): + return self.getToken(RelayParser.QUOTED_STRING, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitStringExpr" ): + return visitor.visitStringExpr(self) else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) + return visitor.visitChildren(self) + + + class CallContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + def callList(self): + return self.getTypedRuleContext(RelayParser.CallListContext,0) def accept(self, visitor:ParseTreeVisitor): @@ -467,7 +755,7 @@ def accept(self, visitor:ParseTreeVisitor): return visitor.visitChildren(self) - class ParensContext(ExprContext): + class ParenContext(ExprContext): def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext super().__init__(parser) @@ -478,25 +766,8 @@ def expr(self): def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitParens" ): - return visitor.visitParens(self) - else: - return visitor.visitChildren(self) - - - class FuncExprContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def func(self): - return self.getTypedRuleContext(RelayParser.FuncContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitFuncExpr" ): - return visitor.visitFuncExpr(self) + if hasattr( visitor, "visitParen" ): + return visitor.visitParen(self) else: return visitor.visitChildren(self) @@ -533,8 +804,6 @@ def expr(self, i:int=None): else: return self.getTypedRuleContext(RelayParser.ExprContext,i) - def MUT(self): - return self.getToken(RelayParser.MUT, 0) def accept(self, visitor:ParseTreeVisitor): if hasattr( visitor, "visitLet" ): @@ -543,22 +812,21 @@ def accept(self, visitor:ParseTreeVisitor): return visitor.visitChildren(self) - class TensorContext(ExprContext): + class ProjectionContext(ExprContext): def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext super().__init__(parser) self.copyFrom(ctx) - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + def NAT(self): + return self.getToken(RelayParser.NAT, 0) def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTensor" ): - return visitor.visitTensor(self) + if hasattr( visitor, "visitProjection" ): + return visitor.visitProjection(self) else: return visitor.visitChildren(self) @@ -586,29 +854,6 @@ def accept(self, visitor:ParseTreeVisitor): return visitor.visitChildren(self) - class GraphContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def ident(self): - return self.getTypedRuleContext(RelayParser.IdentContext,0) - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitGraph" ): - return visitor.visitGraph(self) - else: - return visitor.visitChildren(self) - - class BinOpContext(ExprContext): def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext @@ -636,222 +881,196 @@ def expr(self, _p:int=0): _parentState = self.state localctx = RelayParser.ExprContext(self, self._ctx, _parentState) _prevctx = localctx - _startState = 4 - self.enterRecursionRule(localctx, 4, self.RULE_expr, _p) + _startState = 8 + self.enterRecursionRule(localctx, 8, self.RULE_expr, _p) self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 125 + self.state = 147 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,7,self._ctx) + la_ = self._interp.adaptivePredict(self._input,10,self._ctx) if la_ == 1: - localctx = RelayParser.ParensContext(self, localctx) + localctx = RelayParser.ParenContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 51 - self.match(RelayParser.T__0) - self.state = 52 - self.expr(0) - self.state = 53 + self.state = 82 self.match(RelayParser.T__1) + self.state = 83 + self.expr(0) + self.state = 84 + self.match(RelayParser.T__2) pass elif la_ == 2: + localctx = RelayParser.ParenContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 86 + self.match(RelayParser.T__3) + self.state = 87 + self.expr(0) + self.state = 88 + self.match(RelayParser.T__4) + pass + + elif la_ == 3: localctx = RelayParser.NegContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 55 + self.state = 90 self.match(RelayParser.SUB) - self.state = 56 - self.expr(17) + self.state = 91 + self.expr(19) pass - elif la_ == 3: + elif la_ == 4: localctx = RelayParser.FuncExprContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 57 + self.state = 92 self.func() pass - elif la_ == 4: + elif la_ == 5: localctx = RelayParser.TupleContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 58 - self.match(RelayParser.T__0) - self.state = 59 + self.state = 93 self.match(RelayParser.T__1) + self.state = 94 + self.match(RelayParser.T__2) pass - elif la_ == 5: + elif la_ == 6: localctx = RelayParser.TupleContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 60 - self.match(RelayParser.T__0) - self.state = 61 + self.state = 95 + self.match(RelayParser.T__1) + self.state = 96 self.expr(0) - self.state = 62 + self.state = 97 + self.match(RelayParser.T__0) + self.state = 98 self.match(RelayParser.T__2) - self.state = 63 - self.match(RelayParser.T__1) pass - elif la_ == 6: + elif la_ == 7: localctx = RelayParser.TupleContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 65 - self.match(RelayParser.T__0) - self.state = 66 + self.state = 100 + self.match(RelayParser.T__1) + self.state = 101 self.expr(0) - self.state = 69 + self.state = 104 self._errHandler.sync(self) _la = self._input.LA(1) while True: - self.state = 67 - self.match(RelayParser.T__2) - self.state = 68 + self.state = 102 + self.match(RelayParser.T__0) + self.state = 103 self.expr(0) - self.state = 71 + self.state = 106 self._errHandler.sync(self) _la = self._input.LA(1) - if not (_la==RelayParser.T__2): + if not (_la==RelayParser.T__0): break - self.state = 73 - self.match(RelayParser.T__1) + self.state = 108 + self.match(RelayParser.T__2) pass - elif la_ == 7: + elif la_ == 8: localctx = RelayParser.TensorContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 75 - self.match(RelayParser.T__3) - self.state = 84 + self.state = 110 + self.match(RelayParser.T__6) + self.state = 119 self._errHandler.sync(self) _la = self._input.LA(1) - if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__12) | (1 << RelayParser.SUB) | (1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): - self.state = 76 + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__1) | (1 << RelayParser.T__3) | (1 << RelayParser.T__6) | (1 << RelayParser.T__8) | (1 << RelayParser.T__10) | (1 << RelayParser.T__14) | (1 << RelayParser.T__20) | (1 << RelayParser.QUOTED_STRING) | (1 << RelayParser.SUB) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.CNAME) | (1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT))) != 0): + self.state = 111 self.expr(0) - self.state = 81 + self.state = 116 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==RelayParser.T__2: - self.state = 77 - self.match(RelayParser.T__2) - self.state = 78 + while _la==RelayParser.T__0: + self.state = 112 + self.match(RelayParser.T__0) + self.state = 113 self.expr(0) - self.state = 83 + self.state = 118 self._errHandler.sync(self) _la = self._input.LA(1) - self.state = 86 - self.match(RelayParser.T__4) - pass - - elif la_ == 8: - localctx = RelayParser.IfElseContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 87 - self.match(RelayParser.T__5) - self.state = 88 - self.match(RelayParser.T__0) - self.state = 89 - self.expr(0) - self.state = 90 - self.match(RelayParser.T__1) - self.state = 91 - self.body() - self.state = 92 - self.match(RelayParser.T__6) - self.state = 93 - self.body() + self.state = 121 + self.match(RelayParser.T__7) pass elif la_ == 9: - localctx = RelayParser.LetContext(self, localctx) + localctx = RelayParser.IfElseContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 95 - self.match(RelayParser.T__7) - self.state = 97 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.MUT: - self.state = 96 - self.match(RelayParser.MUT) - - - self.state = 99 - self.var() - self.state = 100 + self.state = 122 self.match(RelayParser.T__8) - self.state = 101 + self.state = 123 + self.match(RelayParser.T__1) + self.state = 124 self.expr(0) - self.state = 102 + self.state = 125 + self.match(RelayParser.T__2) + self.state = 126 + self.body() + self.state = 127 self.match(RelayParser.T__9) - self.state = 103 - self.expr(6) + self.state = 128 + self.body() pass elif la_ == 10: localctx = RelayParser.LetContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 105 - self.match(RelayParser.T__7) - self.state = 107 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.MUT: - self.state = 106 - self.match(RelayParser.MUT) - - - self.state = 109 - self.var() - self.state = 110 - self.match(RelayParser.T__8) - self.state = 111 + self.state = 130 self.match(RelayParser.T__10) - self.state = 112 - self.expr(0) - self.state = 113 + self.state = 131 + self.var() + self.state = 132 self.match(RelayParser.T__11) - self.state = 114 - self.match(RelayParser.T__9) - self.state = 115 - self.expr(5) + self.state = 133 + self.expr(0) + self.state = 134 + self.match(RelayParser.T__12) + self.state = 135 + self.expr(7) pass elif la_ == 11: localctx = RelayParser.GraphContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 117 - self.ident() - self.state = 118 - self.match(RelayParser.T__8) - self.state = 119 + self.state = 137 + self.match(RelayParser.GRAPH_VAR) + self.state = 138 + self.match(RelayParser.T__11) + self.state = 139 self.expr(0) - self.state = 120 - self.match(RelayParser.T__9) - self.state = 121 - self.expr(3) + self.state = 140 + self.match(RelayParser.T__12) + self.state = 141 + self.expr(5) pass elif la_ == 12: localctx = RelayParser.IdentExprContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 123 + self.state = 143 self.ident() pass @@ -859,31 +1078,47 @@ def expr(self, _p:int=0): localctx = RelayParser.ScalarExprContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 124 + self.state = 144 self.scalar() pass + elif la_ == 14: + localctx = RelayParser.MetaExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 145 + self.meta() + pass + + elif la_ == 15: + localctx = RelayParser.StringExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 146 + self.match(RelayParser.QUOTED_STRING) + pass + self._ctx.stop = self._input.LT(-1) - self.state = 157 + self.state = 174 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,11,self._ctx) + _alt = self._interp.adaptivePredict(self._input,12,self._ctx) while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: if _alt==1: if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - self.state = 155 + self.state = 172 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,10,self._ctx) + la_ = self._interp.adaptivePredict(self._input,11,self._ctx) if la_ == 1: localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 127 - if not self.precpred(self._ctx, 16): + self.state = 149 + if not self.precpred(self._ctx, 18): from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 16)") - self.state = 128 + raise FailedPredicateException(self, "self.precpred(self._ctx, 18)") + self.state = 150 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not(_la==RelayParser.MUL or _la==RelayParser.DIV): @@ -891,18 +1126,18 @@ def expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 129 - self.expr(17) + self.state = 151 + self.expr(19) pass elif la_ == 2: localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 130 - if not self.precpred(self._ctx, 15): + self.state = 152 + if not self.precpred(self._ctx, 17): from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 15)") - self.state = 131 + raise FailedPredicateException(self, "self.precpred(self._ctx, 17)") + self.state = 153 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not(_la==RelayParser.ADD or _la==RelayParser.SUB): @@ -910,18 +1145,18 @@ def expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 132 - self.expr(16) + self.state = 154 + self.expr(18) pass elif la_ == 3: localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 133 - if not self.precpred(self._ctx, 14): + self.state = 155 + if not self.precpred(self._ctx, 16): from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 14)") - self.state = 134 + raise FailedPredicateException(self, "self.precpred(self._ctx, 16)") + self.state = 156 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not((((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.LT) | (1 << RelayParser.GT) | (1 << RelayParser.LE) | (1 << RelayParser.GE))) != 0)): @@ -929,18 +1164,18 @@ def expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 135 - self.expr(15) + self.state = 157 + self.expr(17) pass elif la_ == 4: localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 136 - if not self.precpred(self._ctx, 13): + self.state = 158 + if not self.precpred(self._ctx, 15): from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 13)") - self.state = 137 + raise FailedPredicateException(self, "self.precpred(self._ctx, 15)") + self.state = 159 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not(_la==RelayParser.EQ or _la==RelayParser.NE): @@ -948,60 +1183,55 @@ def expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 138 - self.expr(14) + self.state = 160 + self.expr(16) pass elif la_ == 5: localctx = RelayParser.LetContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 139 - if not self.precpred(self._ctx, 4): + self.state = 161 + if not self.precpred(self._ctx, 6): from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 4)") - self.state = 140 - self.match(RelayParser.T__9) - self.state = 141 - self.expr(5) + raise FailedPredicateException(self, "self.precpred(self._ctx, 6)") + self.state = 162 + self.match(RelayParser.T__13) + self.state = 163 + self.expr(7) pass elif la_ == 6: localctx = RelayParser.CallContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 142 - if not self.precpred(self._ctx, 18): + self.state = 164 + if not self.precpred(self._ctx, 20): from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 18)") - self.state = 143 - self.match(RelayParser.T__0) - self.state = 152 - self._errHandler.sync(self) - _la = self._input.LA(1) - if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__12) | (1 << RelayParser.SUB) | (1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): - self.state = 144 - self.expr(0) - self.state = 149 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__2: - self.state = 145 - self.match(RelayParser.T__2) - self.state = 146 - self.expr(0) - self.state = 151 - self._errHandler.sync(self) - _la = self._input.LA(1) - - - - self.state = 154 + raise FailedPredicateException(self, "self.precpred(self._ctx, 20)") + self.state = 165 self.match(RelayParser.T__1) + self.state = 166 + self.callList() + self.state = 167 + self.match(RelayParser.T__2) + pass + + elif la_ == 7: + localctx = RelayParser.ProjectionContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 169 + if not self.precpred(self._ctx, 10): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 10)") + self.state = 170 + self.match(RelayParser.T__5) + self.state = 171 + self.match(RelayParser.NAT) pass - self.state = 159 + self.state = 176 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,11,self._ctx) + _alt = self._interp.adaptivePredict(self._input,12,self._ctx) except RecognitionException as re: localctx.exception = re @@ -1025,8 +1255,8 @@ def body(self): return self.getTypedRuleContext(RelayParser.BodyContext,0) - def typeParamSeq(self): - return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + def typeParamList(self): + return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) def type_(self): @@ -1048,37 +1278,37 @@ def accept(self, visitor:ParseTreeVisitor): def func(self): localctx = RelayParser.FuncContext(self, self._ctx, self.state) - self.enterRule(localctx, 6, self.RULE_func) + self.enterRule(localctx, 10, self.RULE_func) self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 160 - self.match(RelayParser.T__12) - self.state = 162 + self.state = 177 + self.match(RelayParser.T__14) + self.state = 179 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==RelayParser.T__3: - self.state = 161 - self.typeParamSeq() + if _la==RelayParser.T__6: + self.state = 178 + self.typeParamList() - self.state = 164 - self.match(RelayParser.T__0) - self.state = 165 - self.argList() - self.state = 166 + self.state = 181 self.match(RelayParser.T__1) - self.state = 169 + self.state = 182 + self.argList() + self.state = 183 + self.match(RelayParser.T__2) + self.state = 186 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==RelayParser.T__13: - self.state = 167 - self.match(RelayParser.T__13) - self.state = 168 + if _la==RelayParser.T__15: + self.state = 184 + self.match(RelayParser.T__15) + self.state = 185 self.type_() - self.state = 171 + self.state = 188 self.body() except RecognitionException as re: localctx.exception = re @@ -1106,8 +1336,8 @@ def body(self): return self.getTypedRuleContext(RelayParser.BodyContext,0) - def typeParamSeq(self): - return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + def typeParamList(self): + return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) def type_(self): @@ -1129,39 +1359,39 @@ def accept(self, visitor:ParseTreeVisitor): def defn(self): localctx = RelayParser.DefnContext(self, self._ctx, self.state) - self.enterRule(localctx, 8, self.RULE_defn) + self.enterRule(localctx, 12, self.RULE_defn) self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 173 - self.match(RelayParser.T__14) - self.state = 174 + self.state = 190 + self.match(RelayParser.T__16) + self.state = 191 self.ident() - self.state = 176 + self.state = 193 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==RelayParser.T__3: - self.state = 175 - self.typeParamSeq() + if _la==RelayParser.T__6: + self.state = 192 + self.typeParamList() - self.state = 178 - self.match(RelayParser.T__0) - self.state = 179 - self.argList() - self.state = 180 + self.state = 195 self.match(RelayParser.T__1) - self.state = 183 + self.state = 196 + self.argList() + self.state = 197 + self.match(RelayParser.T__2) + self.state = 200 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==RelayParser.T__13: - self.state = 181 - self.match(RelayParser.T__13) - self.state = 182 + if _la==RelayParser.T__15: + self.state = 198 + self.match(RelayParser.T__15) + self.state = 199 self.type_() - self.state = 185 + self.state = 202 self.body() except RecognitionException as re: localctx.exception = re @@ -1177,54 +1407,90 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser + + def getRuleIndex(self): + return RelayParser.RULE_argList + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + + class ArgNoAttrContext(ArgListContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ArgListContext + super().__init__(parser) + self.copyFrom(ctx) + def varList(self): return self.getTypedRuleContext(RelayParser.VarListContext,0) - def attrList(self): - return self.getTypedRuleContext(RelayParser.AttrListContext,0) + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitArgNoAttr" ): + return visitor.visitArgNoAttr(self) + else: + return visitor.visitChildren(self) + + class ArgWithAttrContext(ArgListContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ArgListContext + super().__init__(parser) + self.copyFrom(ctx) + + def attrSeq(self): + return self.getTypedRuleContext(RelayParser.AttrSeqContext,0) + + def var(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.VarContext) + else: + return self.getTypedRuleContext(RelayParser.VarContext,i) - def getRuleIndex(self): - return RelayParser.RULE_argList def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitArgList" ): - return visitor.visitArgList(self) + if hasattr( visitor, "visitArgWithAttr" ): + return visitor.visitArgWithAttr(self) else: return visitor.visitChildren(self) - def argList(self): localctx = RelayParser.ArgListContext(self, self._ctx, self.state) - self.enterRule(localctx, 10, self.RULE_argList) + self.enterRule(localctx, 14, self.RULE_argList) + self._la = 0 # Token type try: - self.state = 193 + self.state = 214 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,16,self._ctx) + la_ = self._interp.adaptivePredict(self._input,18,self._ctx) if la_ == 1: + localctx = RelayParser.ArgNoAttrContext(self, localctx) self.enterOuterAlt(localctx, 1) - self.state = 187 + self.state = 204 self.varList() pass elif la_ == 2: + localctx = RelayParser.ArgWithAttrContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 188 - self.attrList() - pass + self.state = 210 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.LOCAL_VAR: + self.state = 205 + self.var() + self.state = 206 + self.match(RelayParser.T__0) + self.state = 212 + self._errHandler.sync(self) + _la = self._input.LA(1) - elif la_ == 3: - self.enterOuterAlt(localctx, 3) - self.state = 189 - self.varList() - self.state = 190 - self.match(RelayParser.T__2) - self.state = 191 - self.attrList() + self.state = 213 + self.attrSeq() pass @@ -1264,28 +1530,27 @@ def accept(self, visitor:ParseTreeVisitor): def varList(self): localctx = RelayParser.VarListContext(self, self._ctx, self.state) - self.enterRule(localctx, 12, self.RULE_varList) + self.enterRule(localctx, 16, self.RULE_varList) self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 203 + self.state = 224 self._errHandler.sync(self) _la = self._input.LA(1) - if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.CNAME))) != 0): - self.state = 195 + if _la==RelayParser.LOCAL_VAR: + self.state = 216 self.var() - self.state = 200 + self.state = 221 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,17,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: - self.state = 196 - self.match(RelayParser.T__2) - self.state = 197 - self.var() - self.state = 202 + _la = self._input.LA(1) + while _la==RelayParser.T__0: + self.state = 217 + self.match(RelayParser.T__0) + self.state = 218 + self.var() + self.state = 223 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,17,self._ctx) + _la = self._input.LA(1) @@ -1303,9 +1568,8 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser - def ident(self): - return self.getTypedRuleContext(RelayParser.IdentContext,0) - + def LOCAL_VAR(self): + return self.getToken(RelayParser.LOCAL_VAR, 0) def type_(self): return self.getTypedRuleContext(RelayParser.Type_Context,0) @@ -1326,19 +1590,19 @@ def accept(self, visitor:ParseTreeVisitor): def var(self): localctx = RelayParser.VarContext(self, self._ctx, self.state) - self.enterRule(localctx, 14, self.RULE_var) + self.enterRule(localctx, 18, self.RULE_var) self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 205 - self.ident() - self.state = 208 + self.state = 226 + self.match(RelayParser.LOCAL_VAR) + self.state = 229 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==RelayParser.T__15: - self.state = 206 - self.match(RelayParser.T__15) - self.state = 207 + if _la==RelayParser.T__17: + self.state = 227 + self.match(RelayParser.T__17) + self.state = 228 self.type_() @@ -1350,7 +1614,7 @@ def var(self): self.exitRule() return localctx - class AttrListContext(ParserRuleContext): + class AttrSeqContext(ParserRuleContext): def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) @@ -1364,43 +1628,37 @@ def attr(self, i:int=None): def getRuleIndex(self): - return RelayParser.RULE_attrList + return RelayParser.RULE_attrSeq def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAttrList" ): - return visitor.visitAttrList(self) + if hasattr( visitor, "visitAttrSeq" ): + return visitor.visitAttrSeq(self) else: return visitor.visitChildren(self) - def attrList(self): + def attrSeq(self): - localctx = RelayParser.AttrListContext(self, self._ctx, self.state) - self.enterRule(localctx, 16, self.RULE_attrList) + localctx = RelayParser.AttrSeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 20, self.RULE_attrSeq) self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 218 + self.state = 231 + self.attr() + self.state = 236 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==RelayParser.CNAME: - self.state = 210 + while _la==RelayParser.T__0: + self.state = 232 + self.match(RelayParser.T__0) + self.state = 233 self.attr() - self.state = 215 + self.state = 238 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==RelayParser.T__2: - self.state = 211 - self.match(RelayParser.T__2) - self.state = 212 - self.attr() - self.state = 217 - self._errHandler.sync(self) - _la = self._input.LA(1) - - except RecognitionException as re: localctx.exception = re @@ -1438,14 +1696,14 @@ def accept(self, visitor:ParseTreeVisitor): def attr(self): localctx = RelayParser.AttrContext(self, self._ctx, self.state) - self.enterRule(localctx, 18, self.RULE_attr) + self.enterRule(localctx, 22, self.RULE_attr) try: self.enterOuterAlt(localctx, 1) - self.state = 220 + self.state = 239 self.match(RelayParser.CNAME) - self.state = 221 - self.match(RelayParser.T__8) - self.state = 222 + self.state = 240 + self.match(RelayParser.T__11) + self.state = 241 self.expr(0) except RecognitionException as re: localctx.exception = re @@ -1455,7 +1713,7 @@ def attr(self): self.exitRule() return localctx - class TypeParamSeqContext(ParserRuleContext): + class TypeParamListContext(ParserRuleContext): def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) @@ -1469,54 +1727,54 @@ def ident(self, i:int=None): def getRuleIndex(self): - return RelayParser.RULE_typeParamSeq + return RelayParser.RULE_typeParamList def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTypeParamSeq" ): - return visitor.visitTypeParamSeq(self) + if hasattr( visitor, "visitTypeParamList" ): + return visitor.visitTypeParamList(self) else: return visitor.visitChildren(self) - def typeParamSeq(self): + def typeParamList(self): - localctx = RelayParser.TypeParamSeqContext(self, self._ctx, self.state) - self.enterRule(localctx, 20, self.RULE_typeParamSeq) + localctx = RelayParser.TypeParamListContext(self, self._ctx, self.state) + self.enterRule(localctx, 24, self.RULE_typeParamList) self._la = 0 # Token type try: - self.state = 237 + self.state = 256 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,23,self._ctx) + la_ = self._interp.adaptivePredict(self._input,24,self._ctx) if la_ == 1: self.enterOuterAlt(localctx, 1) - self.state = 224 - self.match(RelayParser.T__3) - self.state = 225 - self.match(RelayParser.T__4) + self.state = 243 + self.match(RelayParser.T__6) + self.state = 244 + self.match(RelayParser.T__7) pass elif la_ == 2: self.enterOuterAlt(localctx, 2) - self.state = 226 - self.match(RelayParser.T__3) - self.state = 227 + self.state = 245 + self.match(RelayParser.T__6) + self.state = 246 self.ident() - self.state = 232 + self.state = 251 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==RelayParser.T__2: - self.state = 228 - self.match(RelayParser.T__2) - self.state = 229 + while _la==RelayParser.T__0: + self.state = 247 + self.match(RelayParser.T__0) + self.state = 248 self.ident() - self.state = 234 + self.state = 253 self._errHandler.sync(self) _la = self._input.LA(1) - self.state = 235 - self.match(RelayParser.T__4) + self.state = 254 + self.match(RelayParser.T__7) pass @@ -1617,8 +1875,8 @@ def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Type super().__init__(parser) self.copyFrom(ctx) - def shapeSeq(self): - return self.getTypedRuleContext(RelayParser.ShapeSeqContext,0) + def shapeList(self): + return self.getTypedRuleContext(RelayParser.ShapeListContext,0) def type_(self): return self.getTypedRuleContext(RelayParser.Type_Context,0) @@ -1643,8 +1901,8 @@ def type_(self, i:int=None): else: return self.getTypedRuleContext(RelayParser.Type_Context,i) - def typeParamSeq(self): - return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + def typeParamList(self): + return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) def accept(self, visitor:ParseTreeVisitor): @@ -1658,137 +1916,137 @@ def accept(self, visitor:ParseTreeVisitor): def type_(self): localctx = RelayParser.Type_Context(self, self._ctx, self.state) - self.enterRule(localctx, 22, self.RULE_type_) + self.enterRule(localctx, 26, self.RULE_type_) self._la = 0 # Token type try: - self.state = 284 + self.state = 303 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,28,self._ctx) + la_ = self._interp.adaptivePredict(self._input,29,self._ctx) if la_ == 1: localctx = RelayParser.TupleTypeContext(self, localctx) self.enterOuterAlt(localctx, 1) - self.state = 239 - self.match(RelayParser.T__0) - self.state = 240 + self.state = 258 self.match(RelayParser.T__1) + self.state = 259 + self.match(RelayParser.T__2) pass elif la_ == 2: localctx = RelayParser.TupleTypeContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 241 - self.match(RelayParser.T__0) - self.state = 242 + self.state = 260 + self.match(RelayParser.T__1) + self.state = 261 self.type_() - self.state = 243 + self.state = 262 + self.match(RelayParser.T__0) + self.state = 263 self.match(RelayParser.T__2) - self.state = 244 - self.match(RelayParser.T__1) pass elif la_ == 3: localctx = RelayParser.TupleTypeContext(self, localctx) self.enterOuterAlt(localctx, 3) - self.state = 246 - self.match(RelayParser.T__0) - self.state = 247 + self.state = 265 + self.match(RelayParser.T__1) + self.state = 266 self.type_() - self.state = 250 + self.state = 269 self._errHandler.sync(self) _la = self._input.LA(1) while True: - self.state = 248 - self.match(RelayParser.T__2) - self.state = 249 + self.state = 267 + self.match(RelayParser.T__0) + self.state = 268 self.type_() - self.state = 252 + self.state = 271 self._errHandler.sync(self) _la = self._input.LA(1) - if not (_la==RelayParser.T__2): + if not (_la==RelayParser.T__0): break - self.state = 254 - self.match(RelayParser.T__1) + self.state = 273 + self.match(RelayParser.T__2) pass elif la_ == 4: localctx = RelayParser.TypeIdentTypeContext(self, localctx) self.enterOuterAlt(localctx, 4) - self.state = 256 + self.state = 275 self.typeIdent() pass elif la_ == 5: localctx = RelayParser.TensorTypeContext(self, localctx) self.enterOuterAlt(localctx, 5) - self.state = 257 - self.match(RelayParser.T__16) - self.state = 258 - self.match(RelayParser.T__3) - self.state = 259 - self.shapeSeq() - self.state = 260 - self.match(RelayParser.T__2) - self.state = 261 + self.state = 276 + self.match(RelayParser.T__18) + self.state = 277 + self.match(RelayParser.T__6) + self.state = 278 + self.shapeList() + self.state = 279 + self.match(RelayParser.T__0) + self.state = 280 self.type_() - self.state = 262 - self.match(RelayParser.T__4) + self.state = 281 + self.match(RelayParser.T__7) pass elif la_ == 6: localctx = RelayParser.FuncTypeContext(self, localctx) self.enterOuterAlt(localctx, 6) - self.state = 264 - self.match(RelayParser.T__12) - self.state = 266 + self.state = 283 + self.match(RelayParser.T__14) + self.state = 285 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==RelayParser.T__3: - self.state = 265 - self.typeParamSeq() + if _la==RelayParser.T__6: + self.state = 284 + self.typeParamList() - self.state = 268 - self.match(RelayParser.T__0) - self.state = 277 + self.state = 287 + self.match(RelayParser.T__1) + self.state = 296 self._errHandler.sync(self) _la = self._input.LA(1) - if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__12) | (1 << RelayParser.T__16) | (1 << RelayParser.T__17) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): - self.state = 269 + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__1) | (1 << RelayParser.T__14) | (1 << RelayParser.T__18) | (1 << RelayParser.T__19) | (1 << RelayParser.CNAME) | (1 << RelayParser.NAT))) != 0): + self.state = 288 self.type_() - self.state = 274 + self.state = 293 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==RelayParser.T__2: - self.state = 270 - self.match(RelayParser.T__2) - self.state = 271 + while _la==RelayParser.T__0: + self.state = 289 + self.match(RelayParser.T__0) + self.state = 290 self.type_() - self.state = 276 + self.state = 295 self._errHandler.sync(self) _la = self._input.LA(1) - self.state = 279 - self.match(RelayParser.T__1) - self.state = 280 - self.match(RelayParser.T__13) - self.state = 281 + self.state = 298 + self.match(RelayParser.T__2) + self.state = 299 + self.match(RelayParser.T__15) + self.state = 300 self.type_() pass elif la_ == 7: localctx = RelayParser.IncompleteTypeContext(self, localctx) self.enterOuterAlt(localctx, 7) - self.state = 282 - self.match(RelayParser.T__17) + self.state = 301 + self.match(RelayParser.T__19) pass elif la_ == 8: localctx = RelayParser.IntTypeContext(self, localctx) self.enterOuterAlt(localctx, 8) - self.state = 283 + self.state = 302 self.match(RelayParser.NAT) pass @@ -1801,7 +2059,7 @@ def type_(self): self.exitRule() return localctx - class ShapeSeqContext(ParserRuleContext): + class ShapeListContext(ParserRuleContext): def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) @@ -1815,71 +2073,117 @@ def shape(self, i:int=None): def getRuleIndex(self): - return RelayParser.RULE_shapeSeq + return RelayParser.RULE_shapeList def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitShapeSeq" ): - return visitor.visitShapeSeq(self) + if hasattr( visitor, "visitShapeList" ): + return visitor.visitShapeList(self) else: return visitor.visitChildren(self) - def shapeSeq(self): + def shapeList(self): - localctx = RelayParser.ShapeSeqContext(self, self._ctx, self.state) - self.enterRule(localctx, 24, self.RULE_shapeSeq) + localctx = RelayParser.ShapeListContext(self, self._ctx, self.state) + self.enterRule(localctx, 28, self.RULE_shapeList) self._la = 0 # Token type try: - self.state = 303 + self.state = 318 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,30,self._ctx) + la_ = self._interp.adaptivePredict(self._input,31,self._ctx) if la_ == 1: self.enterOuterAlt(localctx, 1) - self.state = 286 - self.match(RelayParser.T__0) - self.state = 287 - self.match(RelayParser.T__1) - pass - - elif la_ == 2: - self.enterOuterAlt(localctx, 2) - self.state = 288 - self.match(RelayParser.T__0) - self.state = 289 - self.shape() - self.state = 290 - self.match(RelayParser.T__2) - self.state = 291 + self.state = 305 self.match(RelayParser.T__1) - pass - - elif la_ == 3: - self.enterOuterAlt(localctx, 3) - self.state = 293 - self.match(RelayParser.T__0) - self.state = 294 + self.state = 306 self.shape() - self.state = 297 + self.state = 309 self._errHandler.sync(self) _la = self._input.LA(1) while True: - self.state = 295 - self.match(RelayParser.T__2) - self.state = 296 + self.state = 307 + self.match(RelayParser.T__0) + self.state = 308 self.shape() - self.state = 299 + self.state = 311 self._errHandler.sync(self) _la = self._input.LA(1) - if not (_la==RelayParser.T__2): + if not (_la==RelayParser.T__0): break - self.state = 301 + self.state = 313 + self.match(RelayParser.T__2) + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 315 self.match(RelayParser.T__1) + self.state = 316 + self.match(RelayParser.T__2) pass + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 317 + self.shape() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + class MetaContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CNAME(self): + return self.getToken(RelayParser.CNAME, 0) + + def NAT(self): + return self.getToken(RelayParser.NAT, 0) + + def getRuleIndex(self): + return RelayParser.RULE_meta + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitMeta" ): + return visitor.visitMeta(self) + else: + return visitor.visitChildren(self) + + + + def meta(self): + + localctx = RelayParser.MetaContext(self, self._ctx, self.state) + self.enterRule(localctx, 30, self.RULE_meta) + try: + self.enterOuterAlt(localctx, 1) + self.state = 320 + self.match(RelayParser.T__20) + self.state = 321 + self.match(RelayParser.T__6) + self.state = 322 + self.match(RelayParser.CNAME) + self.state = 323 + self.match(RelayParser.T__7) + self.state = 324 + self.match(RelayParser.T__6) + self.state = 325 + self.match(RelayParser.NAT) + self.state = 326 + self.match(RelayParser.T__7) except RecognitionException as re: localctx.exception = re self._errHandler.reportError(self, re) @@ -1921,6 +2225,23 @@ def accept(self, visitor:ParseTreeVisitor): return visitor.visitChildren(self) + class MetaShapeContext(ShapeContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext + super().__init__(parser) + self.copyFrom(ctx) + + def meta(self): + return self.getTypedRuleContext(RelayParser.MetaContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitMetaShape" ): + return visitor.visitMetaShape(self) + else: + return visitor.visitChildren(self) + + class IntShapeContext(ShapeContext): def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext @@ -1941,25 +2262,31 @@ def accept(self, visitor:ParseTreeVisitor): def shape(self): localctx = RelayParser.ShapeContext(self, self._ctx, self.state) - self.enterRule(localctx, 26, self.RULE_shape) + self.enterRule(localctx, 32, self.RULE_shape) try: - self.state = 310 + self.state = 334 self._errHandler.sync(self) token = self._input.LA(1) - if token in [RelayParser.T__0]: - localctx = RelayParser.ParensShapeContext(self, localctx) + if token in [RelayParser.T__20]: + localctx = RelayParser.MetaShapeContext(self, localctx) self.enterOuterAlt(localctx, 1) - self.state = 305 - self.match(RelayParser.T__0) - self.state = 306 - self.shape() - self.state = 307 + self.state = 328 + self.meta() + pass + elif token in [RelayParser.T__1]: + localctx = RelayParser.ParensShapeContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 329 self.match(RelayParser.T__1) + self.state = 330 + self.shape() + self.state = 331 + self.match(RelayParser.T__2) pass elif token in [RelayParser.NAT]: localctx = RelayParser.IntShapeContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 309 + self.enterOuterAlt(localctx, 3) + self.state = 333 self.match(RelayParser.NAT) pass else: @@ -1997,10 +2324,10 @@ def accept(self, visitor:ParseTreeVisitor): def typeIdent(self): localctx = RelayParser.TypeIdentContext(self, self._ctx, self.state) - self.enterRule(localctx, 28, self.RULE_typeIdent) + self.enterRule(localctx, 34, self.RULE_typeIdent) try: self.enterOuterAlt(localctx, 1) - self.state = 312 + self.state = 336 self.match(RelayParser.CNAME) except RecognitionException as re: localctx.exception = re @@ -2035,15 +2362,15 @@ def accept(self, visitor:ParseTreeVisitor): def body(self): localctx = RelayParser.BodyContext(self, self._ctx, self.state) - self.enterRule(localctx, 30, self.RULE_body) + self.enterRule(localctx, 36, self.RULE_body) try: self.enterOuterAlt(localctx, 1) - self.state = 314 - self.match(RelayParser.T__10) - self.state = 315 + self.state = 338 + self.match(RelayParser.T__3) + self.state = 339 self.expr(0) - self.state = 316 - self.match(RelayParser.T__11) + self.state = 340 + self.match(RelayParser.T__4) except RecognitionException as re: localctx.exception = re self._errHandler.reportError(self, re) @@ -2120,27 +2447,27 @@ def accept(self, visitor:ParseTreeVisitor): def scalar(self): localctx = RelayParser.ScalarContext(self, self._ctx, self.state) - self.enterRule(localctx, 32, self.RULE_scalar) + self.enterRule(localctx, 38, self.RULE_scalar) try: - self.state = 321 + self.state = 345 self._errHandler.sync(self) token = self._input.LA(1) if token in [RelayParser.FLOAT]: localctx = RelayParser.ScalarFloatContext(self, localctx) self.enterOuterAlt(localctx, 1) - self.state = 318 + self.state = 342 self.match(RelayParser.FLOAT) pass elif token in [RelayParser.NAT]: localctx = RelayParser.ScalarIntContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 319 + self.state = 343 self.match(RelayParser.NAT) pass elif token in [RelayParser.BOOL_LIT]: localctx = RelayParser.ScalarBoolContext(self, localctx) self.enterOuterAlt(localctx, 3) - self.state = 320 + self.state = 344 self.match(RelayParser.BOOL_LIT) pass else: @@ -2188,29 +2515,29 @@ def accept(self, visitor:ParseTreeVisitor): def ident(self): localctx = RelayParser.IdentContext(self, self._ctx, self.state) - self.enterRule(localctx, 34, self.RULE_ident) + self.enterRule(localctx, 40, self.RULE_ident) try: - self.state = 327 + self.state = 351 self._errHandler.sync(self) token = self._input.LA(1) if token in [RelayParser.CNAME]: self.enterOuterAlt(localctx, 1) - self.state = 323 + self.state = 347 self.opIdent() pass elif token in [RelayParser.GLOBAL_VAR]: self.enterOuterAlt(localctx, 2) - self.state = 324 + self.state = 348 self.match(RelayParser.GLOBAL_VAR) pass elif token in [RelayParser.LOCAL_VAR]: self.enterOuterAlt(localctx, 3) - self.state = 325 + self.state = 349 self.match(RelayParser.LOCAL_VAR) pass elif token in [RelayParser.GRAPH_VAR]: self.enterOuterAlt(localctx, 4) - self.state = 326 + self.state = 350 self.match(RelayParser.GRAPH_VAR) pass else: @@ -2229,7 +2556,7 @@ def ident(self): def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): if self._predicates == None: self._predicates = dict() - self._predicates[2] = self.expr_sempred + self._predicates[4] = self.expr_sempred pred = self._predicates.get(ruleIndex, None) if pred is None: raise Exception("No predicate with index:" + str(ruleIndex)) @@ -2238,27 +2565,31 @@ def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): def expr_sempred(self, localctx:ExprContext, predIndex:int): if predIndex == 0: - return self.precpred(self._ctx, 16) + return self.precpred(self._ctx, 18) if predIndex == 1: - return self.precpred(self._ctx, 15) + return self.precpred(self._ctx, 17) if predIndex == 2: - return self.precpred(self._ctx, 14) + return self.precpred(self._ctx, 16) if predIndex == 3: - return self.precpred(self._ctx, 13) + return self.precpred(self._ctx, 15) if predIndex == 4: - return self.precpred(self._ctx, 4) + return self.precpred(self._ctx, 6) if predIndex == 5: - return self.precpred(self._ctx, 18) + return self.precpred(self._ctx, 20) + + + if predIndex == 6: + return self.precpred(self._ctx, 10) diff --git a/python/tvm/relay/grammar/py3/RelayVisitor.py b/python/tvm/relay/grammar/py3/RelayVisitor.py index 9e3631f5208b..3ea1287d7bcd 100644 --- a/python/tvm/relay/grammar/py3/RelayVisitor.py +++ b/python/tvm/relay/grammar/py3/RelayVisitor.py @@ -19,11 +19,51 @@ def visitProg(self, ctx:RelayParser.ProgContext): return self.visitChildren(ctx) + # Visit a parse tree produced by RelayParser#exprList. + def visitExprList(self, ctx:RelayParser.ExprListContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#callNoAttr. + def visitCallNoAttr(self, ctx:RelayParser.CallNoAttrContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#callWithAttr. + def visitCallWithAttr(self, ctx:RelayParser.CallWithAttrContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#funcExpr. + def visitFuncExpr(self, ctx:RelayParser.FuncExprContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#metaExpr. + def visitMetaExpr(self, ctx:RelayParser.MetaExprContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tensor. + def visitTensor(self, ctx:RelayParser.TensorContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#graph. + def visitGraph(self, ctx:RelayParser.GraphContext): + return self.visitChildren(ctx) + + # Visit a parse tree produced by RelayParser#identExpr. def visitIdentExpr(self, ctx:RelayParser.IdentExprContext): return self.visitChildren(ctx) + # Visit a parse tree produced by RelayParser#stringExpr. + def visitStringExpr(self, ctx:RelayParser.StringExprContext): + return self.visitChildren(ctx) + + # Visit a parse tree produced by RelayParser#call. def visitCall(self, ctx:RelayParser.CallContext): return self.visitChildren(ctx) @@ -39,13 +79,8 @@ def visitTuple(self, ctx:RelayParser.TupleContext): return self.visitChildren(ctx) - # Visit a parse tree produced by RelayParser#parens. - def visitParens(self, ctx:RelayParser.ParensContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#funcExpr. - def visitFuncExpr(self, ctx:RelayParser.FuncExprContext): + # Visit a parse tree produced by RelayParser#paren. + def visitParen(self, ctx:RelayParser.ParenContext): return self.visitChildren(ctx) @@ -59,8 +94,8 @@ def visitLet(self, ctx:RelayParser.LetContext): return self.visitChildren(ctx) - # Visit a parse tree produced by RelayParser#tensor. - def visitTensor(self, ctx:RelayParser.TensorContext): + # Visit a parse tree produced by RelayParser#projection. + def visitProjection(self, ctx:RelayParser.ProjectionContext): return self.visitChildren(ctx) @@ -69,11 +104,6 @@ def visitIfElse(self, ctx:RelayParser.IfElseContext): return self.visitChildren(ctx) - # Visit a parse tree produced by RelayParser#graph. - def visitGraph(self, ctx:RelayParser.GraphContext): - return self.visitChildren(ctx) - - # Visit a parse tree produced by RelayParser#binOp. def visitBinOp(self, ctx:RelayParser.BinOpContext): return self.visitChildren(ctx) @@ -89,8 +119,13 @@ def visitDefn(self, ctx:RelayParser.DefnContext): return self.visitChildren(ctx) - # Visit a parse tree produced by RelayParser#argList. - def visitArgList(self, ctx:RelayParser.ArgListContext): + # Visit a parse tree produced by RelayParser#argNoAttr. + def visitArgNoAttr(self, ctx:RelayParser.ArgNoAttrContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#argWithAttr. + def visitArgWithAttr(self, ctx:RelayParser.ArgWithAttrContext): return self.visitChildren(ctx) @@ -104,8 +139,8 @@ def visitVar(self, ctx:RelayParser.VarContext): return self.visitChildren(ctx) - # Visit a parse tree produced by RelayParser#attrList. - def visitAttrList(self, ctx:RelayParser.AttrListContext): + # Visit a parse tree produced by RelayParser#attrSeq. + def visitAttrSeq(self, ctx:RelayParser.AttrSeqContext): return self.visitChildren(ctx) @@ -114,8 +149,8 @@ def visitAttr(self, ctx:RelayParser.AttrContext): return self.visitChildren(ctx) - # Visit a parse tree produced by RelayParser#typeParamSeq. - def visitTypeParamSeq(self, ctx:RelayParser.TypeParamSeqContext): + # Visit a parse tree produced by RelayParser#typeParamList. + def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext): return self.visitChildren(ctx) @@ -149,8 +184,18 @@ def visitIntType(self, ctx:RelayParser.IntTypeContext): return self.visitChildren(ctx) - # Visit a parse tree produced by RelayParser#shapeSeq. - def visitShapeSeq(self, ctx:RelayParser.ShapeSeqContext): + # Visit a parse tree produced by RelayParser#shapeList. + def visitShapeList(self, ctx:RelayParser.ShapeListContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#meta. + def visitMeta(self, ctx:RelayParser.MetaContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#metaShape. + def visitMetaShape(self, ctx:RelayParser.MetaShapeContext): return self.visitChildren(ctx) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 1de86173040d..fb83032b30aa 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -66,34 +66,34 @@ def conv2d(data, weight : tvm.relay.Expr The weight expressions. - strides : tuple of int, optional + strides : Optional[Tuple[int]] The strides of convolution. - padding : tuple of int, optional + padding : Optional[Tuple[int]] The padding of convolution on both sides of inputs before convolution. - dilation : tuple of int, optional + dilation : Optional[Tuple[int]] Specifies the dilation rate to be used for dilated convolution. - groups : int, optional + groups : Optional[int] Number of groups for grouped convolution. - channels : int, optional + channels : Optional[int] Number of output channels of this convolution. - kernel_size : tuple of int, optional + kernel_size : Optional[Tuple[int]] The spatial of the convolution kernel. - data_layout : str, optional + data_layout : Optional[str] Layout of the input. - kernel_layout : str, optional + kernel_layout : Optional[str] Layout of the weight. - out_layout : str, optional + out_layout : Optional[str] Layout of the output, by default, out_layout is the same as data_layout - out_dtype : str, optional + out_dtype : Optional[str] Specifies the output data type for mixed precision conv2d. Returns @@ -691,8 +691,30 @@ def dropout(data, rate=0.5): result : tvm.relay.Expr The result of dropout """ - result = _make.dropout(data, rate) - return TupleWrapper(result, 2)[0] + return TupleWrapper(dropout_raw(data, rate), 2)[0] + + +def dropout_raw(data, rate=0.5): + """Applies the dropout operation to the input array. + + During training, each element of the input is set to zero with + probability ``p``. The whole array is rescaled by ``1/(1-p)`` + to keep the expected sum of the input unchanged. + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + rate : float, optional (default=0.5) + The probability for an element to be reset to 0. + + Returns + ------- + result : tvm.relay.Expr + The result of dropout + """ + return _make.dropout(data, rate) def batch_norm(data, diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 9218cae3de66..0244debe7a8b 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -23,4 +23,7 @@ def fromtext(data, source_name=None): """Parse a Relay program.""" from tvm.relay import _parser - return _parser.fromtext(data, source_name) + x = _parser.fromtext(data + "\n", source_name) + if x is None: + raise Exception("cannot parse: ", data) + return x diff --git a/python/tvm/relay/testing/densenet.py b/python/tvm/relay/testing/densenet.py index f9b479153bfa..9818f446cf75 100644 --- a/python/tvm/relay/testing/densenet.py +++ b/python/tvm/relay/testing/densenet.py @@ -42,7 +42,7 @@ def _make_dense_block(data, num_layers, bn_size, growth_rate, index): layer_out = data for i in range(num_layers): layer_out = _make_dense_layer(layer_out, growth_rate, bn_size, - "(%s, %s)" % (index, i)) + "%s_%s" % (index, i)) return layer_out def _make_transition(data, num_output_features, index): diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 2f3b7e91aaf7..7e190fc405da 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -29,7 +29,7 @@ def __eq__(self, other): """Compare two Relay types for structural equivalence using alpha equivalence. """ - return bool(_make._type_alpha_equal(self, other)) + return bool(_make._alpha_equal(self, other)) def __ne__(self, other): return not self.__eq__(other) diff --git a/src/lang/reflection.cc b/src/lang/reflection.cc index bc3d2895b811..d6100f3e33e9 100644 --- a/src/lang/reflection.cc +++ b/src/lang/reflection.cc @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2016 by Contributors + * Copyright (c) 2019 by Contributors * \file reflection.cc * \brief Utilities to save/load/construct TVM objects */ @@ -29,10 +29,11 @@ #include #include #include +#include #include #include +#include #include -#include "../common/base64.h" namespace dmlc { DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg); @@ -44,227 +45,10 @@ ::dmlc::Registry* NodeFactoryReg::Registry() { return ::dmlc::Registry::Get(); } -inline std::string Type2String(const Type& t) { - return runtime::TVMType2String(Type2TVMType(t)); -} - - inline Type String2Type(std::string s) { return TVMType2Type(runtime::String2TVMType(s)); } -using runtime::Object; -using runtime::ObjectCell; - -// indexer to index all the ndoes -class NodeIndexer : public AttrVisitor { - public: - std::unordered_map node_index{{nullptr, 0}}; - std::vector node_list{nullptr}; - std::unordered_map tensor_index; - std::vector tensor_list; - std::unordered_map vm_obj_index; - std::vector vm_obj_list; - - void Visit(const char* key, double* value) final {} - void Visit(const char* key, int64_t* value) final {} - void Visit(const char* key, uint64_t* value) final {} - void Visit(const char* key, int* value) final {} - void Visit(const char* key, bool* value) final {} - void Visit(const char* key, std::string* value) final {} - void Visit(const char* key, void** value) final {} - void Visit(const char* key, Type* value) final {} - void Visit(const char* key, NodeRef* value) final { - MakeIndex(value->node_.get()); - } - - void Visit(const char* key, runtime::NDArray* value) final { - DLTensor* ptr = const_cast((*value).operator->()); - if (tensor_index.count(ptr)) return; - CHECK_EQ(tensor_index.size(), tensor_list.size()); - tensor_index[ptr] = tensor_list.size(); - tensor_list.push_back(ptr); - } - - void Visit(const char* key, Object* value) final { - ObjectCell* ptr = value->ptr_.get(); - if (vm_obj_index.count(ptr)) return; - CHECK_EQ(vm_obj_index.size(), vm_obj_list.size()); - vm_obj_index[ptr] = vm_obj_list.size(); - vm_obj_list.push_back(ptr); - } - - // make index of all the children of node - void MakeIndex(Node* node) { - if (node == nullptr) return; - if (node_index.count(node)) return; - CHECK_EQ(node_index.size(), node_list.size()); - node_index[node] = node_list.size(); - node_list.push_back(node); - - if (node->is_type()) { - ArrayNode* n = static_cast(node); - for (const auto& sp : n->data) { - MakeIndex(sp.get()); - } - } else if (node->is_type()) { - MapNode* n = static_cast(node); - for (const auto& kv : n->data) { - MakeIndex(kv.first.get()); - MakeIndex(kv.second.get()); - } - } else if (node->is_type()) { - StrMapNode* n = static_cast(node); - for (const auto& kv : n->data) { - MakeIndex(kv.second.get()); - } - } else { - node->VisitAttrs(this); - } - } -}; - -// use map so attributes are ordered. -using AttrMap = std::map; - -// A Node structure for JSON node. -struct JSONNode { - // The type key of the data - std::string type_key; - // The global key for global object - std::string global_key; - // the attributes - AttrMap attrs; - // container keys - std::vector keys; - // container data - std::vector data; - - void Save(dmlc::JSONWriter *writer) const { - writer->BeginObject(); - writer->WriteObjectKeyValue("type_key", type_key); - if (global_key.size() != 0) { - writer->WriteObjectKeyValue("global_key", global_key); - } - if (attrs.size() != 0) { - writer->WriteObjectKeyValue("attrs", attrs); - } - if (keys.size() != 0) { - writer->WriteObjectKeyValue("keys", keys); - } - if (data.size() != 0) { - writer->WriteObjectKeyValue("data", data); - } - writer->EndObject(); - } - - void Load(dmlc::JSONReader *reader) { - attrs.clear(); - data.clear(); - global_key.clear(); - type_key.clear(); - dmlc::JSONObjectReadHelper helper; - helper.DeclareOptionalField("type_key", &type_key); - helper.DeclareOptionalField("global_key", &global_key); - helper.DeclareOptionalField("attrs", &attrs); - helper.DeclareOptionalField("keys", &keys); - helper.DeclareOptionalField("data", &data); - helper.ReadAllFields(reader); - } -}; - -class JSONAttrGetter : public AttrVisitor { - public: - const std::unordered_map* node_index_; - const std::unordered_map* tensor_index_; - const std::unordered_map* vm_obj_index_; - JSONNode* node_; - - void Visit(const char* key, double* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, int64_t* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, uint64_t* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, int* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, bool* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, std::string* value) final { - node_->attrs[key] = *value; - } - void Visit(const char* key, void** value) final { - LOG(FATAL) << "not allowed to serialize a pointer"; - } - void Visit(const char* key, Type* value) final { - node_->attrs[key] = Type2String(*value); - } - void Visit(const char* key, NodeRef* value) final { - node_->attrs[key] = std::to_string( - node_index_->at(value->node_.get())); - } - void Visit(const char* key, runtime::NDArray* value) final { - node_->attrs[key] = std::to_string( - tensor_index_->at(const_cast((*value).operator->()))); - } - void Visit(const char* key, Object* value) final { - node_->attrs[key] = std::to_string( - vm_obj_index_->at(value->ptr_.get())); - } - // Get the node - void Get(Node* node) { - if (node == nullptr) { - node_->type_key.clear(); - return; - } - node_->type_key = node->type_key(); - // sepcially handle global object - auto* f = dmlc::Registry::Find(node_->type_key); - CHECK(f != nullptr) - << "Node type \'" << node_->type_key << "\' is not registered in TVM"; - if (f->fglobal_key != nullptr) { - node_->global_key = f->fglobal_key(node); - return; - } - node_->attrs.clear(); - node_->data.clear(); - if (node->is_type()) { - ArrayNode* n = static_cast(node); - for (size_t i = 0; i < n->data.size(); ++i) { - node_->data.push_back( - node_index_->at(n->data[i].get())); - } - } else if (node->is_type()) { - MapNode* n = static_cast(node); - for (const auto& kv : n->data) { - node_->data.push_back( - node_index_->at(kv.first.get())); - node_->data.push_back( - node_index_->at(kv.second.get())); - } - } else if (node->is_type()) { - StrMapNode* n = static_cast(node); - for (const auto& kv : n->data) { - node_->keys.push_back(kv.first); - node_->data.push_back( - node_index_->at(kv.second.get())); - } - } else { - // do not need to recover content of global singleton object - // they are registered via the environment - auto* f = dmlc::Registry::Find(node->type_key()); - if (f != nullptr && f->fglobal_key != nullptr) return; - // recursively index normal object. - node->VisitAttrs(this); - } - } -}; - class JSONAttrSetter : public AttrVisitor { public: const std::vector >* node_list_; @@ -360,66 +144,6 @@ class JSONAttrSetter : public AttrVisitor { } }; -// json graph structure to store node -struct JSONGraph { - // the root of the graph - size_t root; - // the nodes of the graph - std::vector nodes; - // base64 b64ndarrays of arrays - std::vector b64ndarrays; - // global attributes - AttrMap attrs; - - void Save(dmlc::JSONWriter *writer) const { - writer->BeginObject(); - writer->WriteObjectKeyValue("root", root); - writer->WriteObjectKeyValue("nodes", nodes); - writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays); - if (attrs.size() != 0) { - writer->WriteObjectKeyValue("attrs", attrs); - } - writer->EndObject(); - } - - void Load(dmlc::JSONReader *reader) { - attrs.clear(); - dmlc::JSONObjectReadHelper helper; - helper.DeclareField("root", &root); - helper.DeclareField("nodes", &nodes); - helper.DeclareOptionalField("b64ndarrays", &b64ndarrays); - helper.DeclareOptionalField("attrs", &attrs); - helper.ReadAllFields(reader); - } - - static JSONGraph Create(const NodeRef& root) { - JSONGraph g; - NodeIndexer indexer; - indexer.MakeIndex(root.node_.get()); - JSONAttrGetter getter; - getter.node_index_ = &indexer.node_index; - getter.tensor_index_ = &indexer.tensor_index; - for (Node* n : indexer.node_list) { - JSONNode jnode; - getter.node_ = &jnode; - getter.Get(n); - g.nodes.emplace_back(std::move(jnode)); - } - g.attrs["tvm_version"] = TVM_VERSION; - g.root = indexer.node_index.at(root.node_.get()); - // serialize tensor - for (DLTensor* tensor : indexer.tensor_list) { - std::string blob; - dmlc::MemoryStringStream mstrm(&blob); - common::Base64OutStream b64strm(&mstrm); - runtime::SaveDLTensor(&b64strm, tensor); - b64strm.Finish(); - g.b64ndarrays.emplace_back(std::move(blob)); - } - return g; - } -}; - std::string SaveJSON(const NodeRef& n) { auto jgraph = JSONGraph::Create(n); std::ostringstream os; diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 5b8ef8ba89fd..11966ed3d8de 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * \file src/tvm/relay/ir/alpha_equal.cc * \brief Alpha equality check by deep comparing two nodes. */ @@ -27,9 +27,10 @@ #include #include #include +#include +#include #include "type_functor.h" #include "../../lang/attr_functor.h" - namespace tvm { namespace relay { @@ -40,8 +41,8 @@ class AlphaEqualHandler: public ExprFunctor, public PatternFunctor { public: - explicit AlphaEqualHandler(bool map_free_var) - : map_free_var_(map_free_var) { } + explicit AlphaEqualHandler(bool map_free_var, bool assert_mode) + : map_free_var_(map_free_var), assert_mode_(assert_mode) { } /*! * Check equality of two nodes. @@ -76,6 +77,9 @@ class AlphaEqualHandler: return AttrEqual(lhs, rhs); } + bool DoubleEqual(double l, double r) { + return true; + } /*! * Check equality of two attributes. * \param lhs The left hand operand. @@ -83,18 +87,28 @@ class AlphaEqualHandler: * \return The comparison result. */ bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) { - if (&lhs == &rhs) return true; - auto lhsd = lhs.as(); - if (lhsd) { - auto rhsd = lhs.as(); - if (!rhsd) return false; - if (lhsd->dict.size() != rhsd->dict.size()) return false; - for (const auto& k : lhsd->dict) { - if (!Equal(k.second, rhsd->dict[k.first])) return false; + auto compute = [&]() { + if (&lhs == &rhs) return true; + if (auto lhsd = lhs.as()) { + auto rhsd = lhs.as(); + if (!rhsd) return false; + if (lhsd->dict.size() != rhsd->dict.size()) return false; + for (const auto& k : lhsd->dict) { + if (!Equal(k.second, rhsd->dict[k.first])) return false; + } + return true; } - return true; - } - return AttrsEqualHandler::Equal(lhs, rhs); + if (auto lhsbn = lhs.as()) { + auto rhsbn = rhs.as(); + if (!rhsbn) return false; + return (lhsbn->axis == rhsbn->axis) + && DoubleEqual(lhsbn->epsilon, rhsbn->epsilon) + && (lhsbn->center == rhsbn->center) + && (lhsbn->scale == rhsbn->scale); + } + return AttrsEqualHandler::Equal(lhs, rhs); + }; + return Compare(compute(), lhs, rhs); } /*! * Check equality of two types. @@ -107,6 +121,13 @@ class AlphaEqualHandler: if (!lhs.defined() || !rhs.defined()) return false; return this->VisitType(lhs, rhs); } + + bool Compare(bool result, const NodeRef& lhs, const NodeRef& rhs) { + if (assert_mode_) { + CHECK(result) << "\n" << AsText(lhs, true) << "\nis not equal to:\n" << AsText(rhs, true); + } + return result; + } /*! * Check equality of two expressions. * @@ -120,18 +141,21 @@ class AlphaEqualHandler: * \return The comparison result. */ bool ExprEqual(const Expr& lhs, const Expr& rhs) { - if (lhs.same_as(rhs)) return true; - if (!lhs.defined() || !rhs.defined()) return false; - auto it = equal_map_.find(lhs); - if (it != equal_map_.end()) { - return it->second.same_as(rhs); - } - if (this->VisitExpr(lhs, rhs)) { - equal_map_[lhs] = rhs; - return true; - } else { - return false; - } + auto compute = [&]() { + if (lhs.same_as(rhs)) return true; + if (!lhs.defined() || !rhs.defined()) return false; + auto it = equal_map_.find(lhs); + if (it != equal_map_.end()) { + return it->second.same_as(rhs); + } + if (this->VisitExpr(lhs, rhs)) { + equal_map_[lhs] = rhs; + return true; + } else { + return false; + } + }; + return Compare(compute(), lhs, rhs); } protected: @@ -516,32 +540,41 @@ class AlphaEqualHandler: private: // whether to map open terms. bool map_free_var_; + // if in assert mode, must return true, and will throw error otherwise. + bool assert_mode_; // renaming of NodeRef to indicate two nodes equals to each other std::unordered_map equal_map_; }; bool AlphaEqual(const Type& lhs, const Type& rhs) { - return AlphaEqualHandler(false).TypeEqual(lhs, rhs); + return AlphaEqualHandler(false, false).TypeEqual(lhs, rhs); } bool AlphaEqual(const Expr& lhs, const Expr& rhs) { - return AlphaEqualHandler(false).ExprEqual(lhs, rhs); + return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs); } // TODO(@jroesch): move to correct namespace? TVM_REGISTER_API("relay._make._alpha_equal") .set_body_typed([](NodeRef a, NodeRef b) { - return AlphaEqualHandler(false).Equal(a, b); + return AlphaEqualHandler(false, false).Equal(a, b); }); -TVM_REGISTER_API("relay._make._type_alpha_equal") -.set_body_typed([](Type a, Type b) { - return AlphaEqualHandler(false).TypeEqual(a, b); +TVM_REGISTER_API("relay._make._assert_alpha_equal") +.set_body_typed([](NodeRef a, NodeRef b) { + bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); + CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not alpha equal"; }); TVM_REGISTER_API("relay._make._graph_equal") .set_body_typed([](NodeRef a, NodeRef b) { - return AlphaEqualHandler(true).Equal(a, b); + return AlphaEqualHandler(true, false).Equal(a, b); +}); + +TVM_REGISTER_API("relay._make._assert_graph_equal") +.set_body_typed([](NodeRef a, NodeRef b) { + bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b); + CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not graph equal"; }); } // namespace relay diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index f786ed7def6f..da6f8a5c3a27 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -89,7 +89,7 @@ std::string Doc::str() { return os.str(); } -Doc PrintVec(const std::vector& vec, const Doc& sep) { +Doc PrintSep(const std::vector& vec, const Doc& sep) { Doc seq; if (vec.size() != 0) { seq = vec[0]; diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index dc7e79b43b01..a7283125ae31 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -46,7 +46,12 @@ using DocAtom = std::shared_ptr; struct TextNode : DocAtomNode { std::string str; - explicit TextNode(const std::string& str) : str(str) {} + explicit TextNode(const std::string& str) : str(str) { + if (str.find_first_of("\t\n") != str.npos) { + LOG(FATAL) << "text node: '" << str << "' should not has tab or newline."; + throw; + } + } }; struct LineNode : DocAtomNode { @@ -91,8 +96,8 @@ class Doc { // DSL functions -// Render vectors of docs with a separator. e.g. PrintVec([1, 2, 3], f) -> 1f2f3 -Doc PrintVec(const std::vector& vec, const Doc& sep = Doc(", ")); +// Render vectors of docs with a separator. e.g. PrintSep([1, 2, 3], f) -> 1f2f3 +Doc PrintSep(const std::vector& vec, const Doc& sep = Doc(", ")); // Print a constant bool value. Doc PrintBool(bool value); // Print a data type. @@ -116,7 +121,8 @@ Doc PrintConstScalar(DataType dtype, const T* data) { } else if (dtype == Bool()) { return PrintBool(data[0] != 0); } else { - os << dtype << "(" << data[0] << ")"; + // todo(@M.K.) this is unsafe. fix. + os << data[0]; } return Doc(os.str()); } diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 09196b49a617..e6510dec69b9 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -32,9 +32,11 @@ * - Otherwise, inline if the node is at the end of a scope and is used at most once. */ +#include #include #include #include +#include #include "doc.h" #include "type_functor.h" #include "../pass/dependency_graph.h" @@ -43,6 +45,17 @@ namespace tvm { namespace relay { +Doc Brace(const Doc& d, + const std::string& open = "{", + const std::string& close = "}", + int indent = 2) { + Doc doc; + doc << open; + doc << Indent(indent, PrintNewLine() << d) << PrintNewLine(); + doc << close; + return doc; +} + /*! * \brief Meta data context for PrettyPrinter. * @@ -108,8 +121,10 @@ class TextMetaDataContext { if (it != meta_repr_.end()) { return it->second; } + std::string type_key = node->type_key(); + CHECK(!type_key.empty()); Array& mvector = - meta_data_[node->type_key()]; + meta_data_[type_key]; int64_t index = static_cast(mvector.size()); mvector.push_back(node); Doc doc; @@ -117,14 +132,80 @@ class TextMetaDataContext { meta_repr_[node] = doc; return meta_repr_[node]; } + + Doc PrintKeyValue(const std::string& str, const Doc& v) const { + return Doc("\"") << str << "\": " << v; + } + + template + Doc PrintVector(const std::vector& vec) const { + std::vector docs; + for (const auto& t : vec) { + docs.push_back(Doc(t)); + } + return Doc("[") << PrintSep(docs) << "]"; + } + + Doc PrintAttrMap(const AttrMap& m) const { + std::vector docs; + for (const auto& p : m) { + docs.push_back(PrintKeyValue(p.first, PrintString(p.second))); + } + return Brace(PrintSep(docs, Doc(",") << PrintNewLine())); + } + + Doc PrintJSONNode(const JSONNode& n) const { + std::vector docs; + docs.push_back(PrintKeyValue("type_key", PrintString(n.type_key))); + if (!n.global_key.empty()) { + docs.push_back(PrintKeyValue("global_key", Doc(n.global_key))); + } + if (!n.attrs.empty()) { + docs.push_back(PrintKeyValue("attrs", PrintAttrMap(n.attrs))); + } + std::vector keys; + for (const auto& k : n.keys) { + keys.push_back(PrintString(k)); + } + if (!n.keys.empty()) { + docs.push_back(PrintKeyValue("keys", PrintVector(keys))); + } + if (!n.data.empty()) { + docs.push_back(PrintKeyValue("data", PrintVector(n.data))); + } + return Brace(PrintSep(docs, Doc(",") << PrintNewLine())); + } + + Doc PrintJSONGraph(const JSONGraph& j) const { + std::vector docs; + docs.push_back(PrintKeyValue("root", Doc(j.root))); + std::vector nodes; + for (const auto& node : j.nodes) { + nodes.push_back(PrintJSONNode(node)); + } + docs.push_back(PrintKeyValue("nodes", + Brace(PrintSep(nodes, Doc(",") << PrintNewLine()), + "[", + "]"))); + std::vector b64; + for (const auto& b : j.b64ndarrays) { + b64.push_back(PrintString(b)); + } + docs.push_back(PrintKeyValue("b64ndarrays", PrintVector(b64))); + if (!j.attrs.empty()) { + docs.push_back(PrintKeyValue("attrs", PrintAttrMap(j.attrs))); + } + return Brace(Doc(PrintSep(docs, Doc(",") << PrintNewLine()))); + } + /*! * \brief Get the metadata section in json format. * \return the meta data string. */ - std::string GetMetaSection() const { - if (meta_data_.size() == 0) return std::string(); - return SaveJSON(Map( - meta_data_.begin(), meta_data_.end())); + Doc GetMetaSection() const { + if (meta_data_.size() == 0) return Doc(); + auto m = Map(meta_data_.begin(), meta_data_.end()); + return PrintJSONGraph(JSONGraph::Create(m)); } /*! \return whether the meta data context is empty. */ @@ -172,12 +253,11 @@ class PrettyPrinter : } // indent a new body - // TODO(jmp): indent should be an instance variable of the printer Doc PrintBody(const NodeRef& node, int indent = 2) { Doc doc; Doc body; doc << "{"; - doc << Indent(indent, body << "\n" << PrintScope(node)) << "\n"; + doc << Indent(indent, body << PrintNewLine() << PrintScope(node)) << PrintNewLine(); doc << "}"; return doc; } @@ -203,13 +283,12 @@ class PrettyPrinter : Doc doc; doc << PrintScope(node); if (!meta_.empty()) { + doc << PrintNewLine(); if (show_meta_data_) { - std::string meta_json = meta_.GetMetaSection(); // append meta data in the end. - doc << "\n" << "/* meta data */" << "\n" << meta_json; + doc << "METADATA:" << PrintNewLine() << meta_.GetMetaSection(); } else { - doc << "\n" - << "// meta data omitted. you can use show_meta_data=True to include meta data"; + doc << "// meta data omitted. you can use show_meta_data=True to include meta data"; } } return doc; @@ -361,7 +440,7 @@ class PrettyPrinter : // wrap GNFed let in brackets Doc body; printed_expr << "{"; - printed_expr << Indent(2, body << "\n" << VisitExpr(expr)) << "\n"; + printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine(); printed_expr << "}"; } else { printed_expr = VisitExpr(expr); @@ -373,7 +452,7 @@ class PrettyPrinter : if (expr.as()) { // This is our first time visiting the var and we hit the VarNode case // in the visitor. Thus the variable is free. - doc_stack_.back() << "free_var " << printed_expr << "\n"; + doc_stack_.back() << "free_var " << printed_expr << PrintNewLine(); // Memoization is done in AllocVar. return memo_[expr]; } else if (inline_expr) { @@ -422,7 +501,7 @@ class PrettyPrinter : fields.push_back(Print(field)); } Doc doc; - doc << "(" << PrintVec(fields); + doc << "(" << PrintSep(fields); // conform to python tuple format (1,) if (op->fields.size() == 1) { doc << ","; @@ -460,31 +539,31 @@ class PrettyPrinter : } Doc PrintFunc(const Doc& prefix, const Function& fn) { - Doc doc; - doc << prefix; - if (fn->type_params.size() > 0) { - doc << "<"; - std::vector type_params; - for (const TypeVar& tv : fn->type_params) { - type_params.push_back(AllocTypeVar(tv)); - } - doc << PrintVec(type_params); - doc << ">"; - } - doc << "("; - std::vector params; - for (Var param : fn->params) { - params.push_back(AllocVar(param)); - } - for (const Doc& d : PrintFuncAttrs(fn->attrs)) { - params.push_back(d); - } - doc << PrintVec(params) << ") "; - if (fn->ret_type.defined()) { - doc << "-> " << Print(fn->ret_type) << " "; + Doc doc; + doc << prefix; + if (fn->type_params.size() > 0) { + doc << "<"; + std::vector type_params; + for (const TypeVar& tv : fn->type_params) { + type_params.push_back(AllocTypeVar(tv)); } - doc << PrintBody(fn->body); - return doc; + doc << PrintSep(type_params); + doc << ">"; + } + doc << "("; + std::vector params; + for (Var param : fn->params) { + params.push_back(AllocVar(param)); + } + for (const Doc& d : PrintFuncAttrs(fn->attrs)) { + params.push_back(d); + } + doc << PrintSep(params) << ") "; + if (fn->ret_type.defined()) { + doc << "-> " << Print(fn->ret_type) << " "; + } + doc << PrintBody(fn->body); + return doc; } Doc PrintMod(const Module& mod) { @@ -493,13 +572,13 @@ class PrettyPrinter : for (const auto& kv : mod->functions) { dg_ = DependencyGraph::Create(&arena_, kv.second); - std::ostringstream os; if (counter++ != 0) { - doc << "\n"; + doc << PrintNewLine(); } + std::ostringstream os; os << "def @" << kv.first->name_hint; doc << PrintFunc(Doc(os.str()), kv.second); - doc << "\n"; + doc << PrintNewLine(); } return doc; } @@ -528,7 +607,7 @@ class PrettyPrinter : args.push_back(d); } doc << Print(op->op); - return doc << "(" << PrintVec(args) << ")"; + return doc << "(" << PrintSep(args) << ")"; } Doc VisitExpr_(const RefCreateNode* op) final { @@ -558,7 +637,7 @@ class PrettyPrinter : clauses.push_back(clause_doc << Print(clause->lhs) << " -> " << Print(clause->rhs)); } - doc << Indent(2, body << "\n" << PrintVec(clauses, Doc("\n"))) << "\n"; + doc << Indent(2, body << PrintNewLine() << PrintSep(clauses, PrintNewLine())) << PrintNewLine(); doc << "}"; return doc; } @@ -570,7 +649,7 @@ class PrettyPrinter : for (const auto& pat : p->patterns) { pats.push_back(Print(pat)); } - return doc << PrintVec(pats) << ")"; + return doc << PrintSep(pats) << ")"; } Doc VisitPattern_(const PatternVarNode* pv) final { @@ -617,7 +696,7 @@ class PrettyPrinter : args.push_back(PrintType(t, false)); } doc << "["; - doc << PrintVec(args); + doc << PrintSep(args); doc << "]"; return doc; } @@ -633,11 +712,7 @@ class PrettyPrinter : for (NodeRef shape : node->shape) { shapes.push_back(PrintAttr(shape)); } - doc << PrintVec(shapes); - // conform to python tuple format (1,) - if (node->shape.size() == 1) { - doc << ","; - } + doc << PrintSep(shapes); return doc << "), " << PrintDType(node->dtype) << "]"; } @@ -647,7 +722,7 @@ class PrettyPrinter : fields.push_back(Print(field)); } Doc doc; - doc << "(" << PrintVec(fields); + doc << "(" << PrintSep(fields); // conform to python tuple format (1,) if (node->fields.size() == 1) { doc << ","; @@ -664,14 +739,14 @@ class PrettyPrinter : for (Type type_param : node->type_params) { type_params.push_back(Print(type_param)); } - doc << PrintVec(type_params); + doc << PrintSep(type_params); doc << ">"; } std::vector arg_types; for (Type arg_type : node->arg_types) { arg_types.push_back(Print(arg_type)); } - return doc << "(" << PrintVec(arg_types) << ") -> " << Print(node->ret_type); + return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type); } Doc VisitType_(const RefTypeNode* node) final { @@ -710,7 +785,7 @@ class PrettyPrinter : for (NodePtr val : op->data) { arr_vals.push_back(PrintAttr(NodeRef(val))); } - doc << PrintVec(arr_vals); + doc << PrintSep(arr_vals); doc << "]"; return doc; } @@ -771,7 +846,9 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { } void Visit(const char* key, double* value) final { - PrintKV(key, *value); + Doc doc; + doc << key << "=" << *value << "f"; + docs->push_back(doc); } void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); @@ -843,7 +920,7 @@ std::string PrettyPrint_(const NodeRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate) { Doc doc; - doc << "v0.0.3" << "\n" + doc << "v0.0.3" << PrintNewLine() << PrettyPrinter(show_meta_data, annotate).PrintFinal(node); return doc.str(); } diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 999b14d02cfc..26742cd4bf8d 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -16,7 +16,7 @@ # under the License. import tvm from tvm import relay -from tvm.relay.analysis import alpha_equal +from tvm.relay.analysis import alpha_equal, assert_alpha_equal from nose.tools import nottest, raises from numpy import isclose from typing import Union @@ -60,12 +60,9 @@ "float16x4", } -def assert_alpha_equal(a, b): - if not alpha_equal(a, b): - raise Exception("lhs is: ", str(a), "rhs is: ", str(b)) - def roundtrip(expr): - assert_alpha_equal(relay.fromtext(str(expr)), expr) + x = relay.fromtext(str(expr)) + assert_alpha_equal(x, expr) def parse_text(code): @@ -112,6 +109,16 @@ def test_comments(): UNIT ) + assert parses_as( + """ + /* This is a block comment! + /*Block comment is recursive!*/ + */ + () + """, + UNIT + ) + def test_int_literal(): assert isinstance(parse_text("1"), relay.Constant) @@ -224,7 +231,7 @@ def test_let(): def test_seq(): assert parses_as( - "(); ()", + "();; ()", relay.Let( _, UNIT, @@ -538,7 +545,7 @@ def test_tensor_type(): ) assert parses_as( - "let %_ : Tensor[(1,), float32] = (); ()", + "let %_ : Tensor[(1), float32] = (); ()", relay.Let( relay.Var("_", relay.TensorType((1,), "float32")), UNIT, diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 32e6cde3dde2..b55261cb5b58 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -15,14 +15,27 @@ # specific language governing permissions and limitations # under the License. import tvm +from tvm import relay import tvm.relay.testing import numpy as np -from tvm import relay +from tvm.relay import Expr +from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equal, free_vars do_print = [False] SEMVER = "v0.0.3\n" +def astext(p, graph_equal=False): + txt = p.astext() + if isinstance(p, Expr) and free_vars(p): + return txt + x = relay.fromtext(txt) + if graph_equal: + assert_graph_equal(x, p) + else: + assert_alpha_equal(x, p) + return txt + def show(text): if do_print[0]: print("---------------------------") @@ -35,8 +48,8 @@ def test_func(): z = relay.add(x, one) z = relay.add(z, z) f = relay.Function([x, y], z) - show(z.astext()) - show(f.astext()) + show(astext(z)) + show(astext(f)) def test_env(): @@ -47,7 +60,7 @@ def test_env(): f = relay.Function([x, y], z) env = relay.Module() env["myf"] = f - text = env.astext() + text = astext(env) assert "def @myf" in text assert "def @myf" in str(env) assert "add(%0, %0) /* ty=float32 */" in text @@ -65,7 +78,7 @@ def test_meta_data(): padding=(1, 1), channels=2) f = relay.Function([x, w], z) - text = f.astext() + text = astext(f, graph_equal=True) text_no_meta = str(f) assert "channels=2" in text assert "channels=2" in text_no_meta @@ -73,25 +86,22 @@ def test_meta_data(): assert "meta[Variable][0]" in text_no_meta assert "type_key" in text assert "type_key" not in text_no_meta - show(text) - show(f) - text = relay.const([1,2,3]).astext() + text = astext(relay.const([1,2,3])) assert "meta[relay.Constant][0]" in text - show(text) def test_call_attrs(): x = relay.var("x") # non default args z = relay.nn.softmax(x, axis=2) - assert "axis=2" in z.astext() + assert "axis=2" in astext(z) # default args z = relay.nn.softmax(x) - assert "softmax(%x)" in z.astext() + assert "softmax(%x)" in astext(z) # non default args z = relay.expand_dims(x, axis=2, num_newaxis=2) - assert "num_newaxis=2" in z.astext() + assert "num_newaxis=2" in astext(z) def test_let_if_scope(): @@ -111,68 +121,72 @@ def test_let_if_scope(): result = sb.get() f = relay.Function([x, y, cond], result) - text = f.astext() + text = astext(f) assert text.count("{") == 4 assert "%cond: bool" in text - show(f.astext()) + show(astext(f)) def test_variable_name(): # avoid pure number even if the namehint is pure number v1 = relay.var("1") - assert "%v1" in v1.astext() + assert "%v1" in astext(v1) def test_mlp(): net, params = tvm.relay.testing.mlp.get_workload(batch_size=1) - net.astext() + astext(net) def test_resnet(): net, params = tvm.relay.testing.resnet.get_workload(batch_size=1) - net.astext() + astext(net) def test_mobilenet(): net, params = tvm.relay.testing.mobilenet.get_workload(batch_size=1) - net.astext() + astext(net) def test_dqn(): net, params = tvm.relay.testing.dqn.get_workload(batch_size=1) - net.astext() + astext(net) def test_dcgan(): net, params = tvm.relay.testing.dcgan.get_workload(batch_size=1) - net.astext() + astext(net) def test_lstm(): + net, params = tvm.relay.testing.lstm.get_workload(1, 1) + astext(net) + net, params = tvm.relay.testing.lstm.get_workload(4, 4) - net.astext() + astext(net) def test_inception_v3(): net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1) - net.astext() + astext(net) def test_squeezenet(): for version in ['1.0', '1.1']: net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version) - net.astext() + astext(net) def test_vgg(): net, params = tvm.relay.testing.vgg.get_workload(batch_size=1) - net.astext() + astext(net) def test_densenet(): net, params = tvm.relay.testing.densenet.get_workload(batch_size=1) - net.astext() + astext(net) def test_call_node_order(): x = relay.var("x") y = relay.var("y") - assert relay.Call(relay.Function([x], x), [relay.Call(relay.Function([y], y), [relay.const(1)])]).astext() == SEMVER + \ + prog = relay.Call(relay.Function([x], x), [relay.Call(relay.Function([y], y), [relay.const(1)])]) + assert astext(prog) == SEMVER + \ ("%0 = fn (%y) {\n" " %y\n" "};\n" @@ -185,17 +199,25 @@ def test_call_node_order(): def test_let_inlining(): tup = relay.Tuple([relay.const(0), relay.const(0)]) x = relay.var("x") - assert relay.Let(x, tup, tup).astext() == SEMVER + \ + assert astext(relay.Let(x, tup, tup)) == SEMVER + \ ("%0 = (0, 0);\n" "let %x = %0;\n" "%0") - assert relay.Let(x, tup, x).astext() == SEMVER + \ + assert astext(relay.Let(x, tup, x)) == SEMVER + \ ("let %x = (0, 0);\n" "%x") +def test_zeros(): + x = relay.op.zeros([], "float32") + astext(x) + if __name__ == "__main__": do_print[0] = True + test_lstm() + test_zeros() + test_meta_data() + test_let_inlining() test_resnet() test_mobilenet() test_mlp() @@ -207,9 +229,7 @@ def test_let_inlining(): test_densenet() test_func() test_env() - test_meta_data() test_call_attrs() test_let_if_scope() test_variable_name() test_call_node_order() - test_let_inlining() From 1cd81bd8b42378e4890039b4928610973cc5ca3e Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Sat, 13 Jul 2019 12:55:42 -0700 Subject: [PATCH 2/3] fix test --- python/tvm/relay/_parser.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 3c0077d4e078..3fb4b2342bd7 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -141,9 +141,9 @@ def __call__(self, args, attrs, type_args): "bool", ] -T = TypeVar("T") -Scope = Deque[Tuple[str, T]] -Scopes = Deque[Scope[T]] +T = ty.TypeVar("T") +# Scope = Deque[Tuple[str, T]] +# Scopes = Deque[Scope[T]] def lookup(scopes, name): # type: (Scopes[T], str) -> Optional[T] From 5d3d5d4b3a44ef87379a848b12661b3e77760d5c Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Thu, 18 Jul 2019 01:16:51 -0700 Subject: [PATCH 3/3] revert json changes --- include/tvm/json.h | 317 --------------------------- {include/tvm => src}/common/base64.h | 0 src/lang/reflection.cc | 282 +++++++++++++++++++++++- src/relay/ir/doc.h | 7 +- src/relay/ir/pretty_printer.cc | 65 +----- 5 files changed, 283 insertions(+), 388 deletions(-) delete mode 100644 include/tvm/json.h rename {include/tvm => src}/common/base64.h (100%) diff --git a/include/tvm/json.h b/include/tvm/json.h deleted file mode 100644 index 8f681f24abf9..000000000000 --- a/include/tvm/json.h +++ /dev/null @@ -1,317 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file json.h - * \brief A representation of JSON - */ - -#ifndef TVM_JSON_H_ -#define TVM_JSON_H_ - -#include -#include - -#include -#include -#include -#include -#include - -namespace tvm { - -// use map so attributes are ordered. -using AttrMap = std::map; - -using runtime::Object; -using runtime::ObjectCell; - -inline std::string Type2String(const Type& t) { - return runtime::TVMType2String(Type2TVMType(t)); -} - -// indexer to index all the ndoes -class NodeIndexer : public AttrVisitor { - public: - std::unordered_map node_index{{nullptr, 0}}; - std::vector node_list{nullptr}; - std::unordered_map tensor_index; - std::vector tensor_list; - std::unordered_map vm_obj_index; - std::vector vm_obj_list; - - void Visit(const char* key, double* value) final {} - void Visit(const char* key, int64_t* value) final {} - void Visit(const char* key, uint64_t* value) final {} - void Visit(const char* key, int* value) final {} - void Visit(const char* key, bool* value) final {} - void Visit(const char* key, std::string* value) final {} - void Visit(const char* key, void** value) final {} - void Visit(const char* key, Type* value) final {} - void Visit(const char* key, NodeRef* value) final { - MakeIndex(value->node_.get()); - } - - void Visit(const char* key, runtime::NDArray* value) final { - DLTensor* ptr = const_cast((*value).operator->()); - if (tensor_index.count(ptr)) return; - CHECK_EQ(tensor_index.size(), tensor_list.size()); - tensor_index[ptr] = tensor_list.size(); - tensor_list.push_back(ptr); - } - - void Visit(const char* key, Object* value) final { - ObjectCell* ptr = value->ptr_.get(); - if (vm_obj_index.count(ptr)) return; - CHECK_EQ(vm_obj_index.size(), vm_obj_list.size()); - vm_obj_index[ptr] = vm_obj_list.size(); - vm_obj_list.push_back(ptr); - } - - // make index of all the children of node - void MakeIndex(Node* node) { - if (node == nullptr) return; - if (node_index.count(node)) return; - CHECK_EQ(node_index.size(), node_list.size()); - node_index[node] = node_list.size(); - node_list.push_back(node); - - if (node->is_type()) { - ArrayNode* n = static_cast(node); - for (const auto& sp : n->data) { - MakeIndex(sp.get()); - } - } else if (node->is_type()) { - MapNode* n = static_cast(node); - for (const auto& kv : n->data) { - MakeIndex(kv.first.get()); - MakeIndex(kv.second.get()); - } - } else if (node->is_type()) { - StrMapNode* n = static_cast(node); - for (const auto& kv : n->data) { - MakeIndex(kv.second.get()); - } - } else { - node->VisitAttrs(this); - } - } -}; - -// A Node structure for JSON node. -struct JSONNode { - // The type key of the data - std::string type_key; - // The global key for global object - std::string global_key; - // the attributes - AttrMap attrs; - // container keys - std::vector keys; - // container data - std::vector data; - - void Save(dmlc::JSONWriter *writer) const { - writer->BeginObject(); - writer->WriteObjectKeyValue("type_key", type_key); - if (global_key.size() != 0) { - writer->WriteObjectKeyValue("global_key", global_key); - } - if (attrs.size() != 0) { - writer->WriteObjectKeyValue("attrs", attrs); - } - if (keys.size() != 0) { - writer->WriteObjectKeyValue("keys", keys); - } - if (data.size() != 0) { - writer->WriteObjectKeyValue("data", data); - } - writer->EndObject(); - } - - void Load(dmlc::JSONReader *reader) { - attrs.clear(); - data.clear(); - global_key.clear(); - type_key.clear(); - dmlc::JSONObjectReadHelper helper; - helper.DeclareOptionalField("type_key", &type_key); - helper.DeclareOptionalField("global_key", &global_key); - helper.DeclareOptionalField("attrs", &attrs); - helper.DeclareOptionalField("keys", &keys); - helper.DeclareOptionalField("data", &data); - helper.ReadAllFields(reader); - } -}; - -class JSONAttrGetter : public AttrVisitor { - public: - const std::unordered_map* node_index_; - const std::unordered_map* tensor_index_; - const std::unordered_map* vm_obj_index_; - JSONNode* node_; - - void Visit(const char* key, double* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, int64_t* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, uint64_t* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, int* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, bool* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, std::string* value) final { - node_->attrs[key] = *value; - } - void Visit(const char* key, void** value) final { - LOG(FATAL) << "not allowed to serialize a pointer"; - } - void Visit(const char* key, Type* value) final { - node_->attrs[key] = Type2String(*value); - } - void Visit(const char* key, NodeRef* value) final { - node_->attrs[key] = std::to_string( - node_index_->at(value->node_.get())); - } - void Visit(const char* key, runtime::NDArray* value) final { - node_->attrs[key] = std::to_string( - tensor_index_->at(const_cast((*value).operator->()))); - } - void Visit(const char* key, Object* value) final { - node_->attrs[key] = std::to_string( - vm_obj_index_->at(value->ptr_.get())); - } - // Get the node - void Get(Node* node) { - if (node == nullptr) { - node_->type_key.clear(); - return; - } - node_->type_key = node->type_key(); - // sepcially handle global object - auto* f = dmlc::Registry::Find(node_->type_key); - CHECK(f != nullptr) - << "Node type \'" << node_->type_key << "\' is not registered in TVM"; - if (f->fglobal_key != nullptr) { - node_->global_key = f->fglobal_key(node); - return; - } - node_->attrs.clear(); - node_->data.clear(); - if (node->is_type()) { - ArrayNode* n = static_cast(node); - for (size_t i = 0; i < n->data.size(); ++i) { - node_->data.push_back( - node_index_->at(n->data[i].get())); - } - } else if (node->is_type()) { - MapNode* n = static_cast(node); - for (const auto& kv : n->data) { - node_->data.push_back( - node_index_->at(kv.first.get())); - node_->data.push_back( - node_index_->at(kv.second.get())); - } - } else if (node->is_type()) { - StrMapNode* n = static_cast(node); - for (const auto& kv : n->data) { - node_->keys.push_back(kv.first); - node_->data.push_back( - node_index_->at(kv.second.get())); - } - } else { - // do not need to recover content of global singleton object - // they are registered via the environment - auto* f = dmlc::Registry::Find(node->type_key()); - if (f != nullptr && f->fglobal_key != nullptr) return; - // recursively index normal object. - node->VisitAttrs(this); - } - } -}; - -// json graph structure to store node -struct JSONGraph { - // the root of the graph - size_t root; - // the nodes of the graph - std::vector nodes; - // base64 b64ndarrays of arrays - std::vector b64ndarrays; - // global attributes - AttrMap attrs; - - void Save(dmlc::JSONWriter *writer) const { - writer->BeginObject(); - writer->WriteObjectKeyValue("root", root); - writer->WriteObjectKeyValue("nodes", nodes); - writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays); - if (attrs.size() != 0) { - writer->WriteObjectKeyValue("attrs", attrs); - } - writer->EndObject(); - } - - void Load(dmlc::JSONReader *reader) { - attrs.clear(); - dmlc::JSONObjectReadHelper helper; - helper.DeclareField("root", &root); - helper.DeclareField("nodes", &nodes); - helper.DeclareOptionalField("b64ndarrays", &b64ndarrays); - helper.DeclareOptionalField("attrs", &attrs); - helper.ReadAllFields(reader); - } - - static JSONGraph Create(const NodeRef& root) { - JSONGraph g; - NodeIndexer indexer; - indexer.MakeIndex(root.node_.get()); - JSONAttrGetter getter; - getter.node_index_ = &indexer.node_index; - getter.tensor_index_ = &indexer.tensor_index; - for (Node* n : indexer.node_list) { - JSONNode jnode; - getter.node_ = &jnode; - getter.Get(n); - g.nodes.emplace_back(std::move(jnode)); - } - g.attrs["tvm_version"] = TVM_VERSION; - g.root = indexer.node_index.at(root.node_.get()); - // serialize tensor - for (DLTensor* tensor : indexer.tensor_list) { - std::string blob; - dmlc::MemoryStringStream mstrm(&blob); - common::Base64OutStream b64strm(&mstrm); - runtime::SaveDLTensor(&b64strm, tensor); - b64strm.Finish(); - g.b64ndarrays.emplace_back(std::move(blob)); - } - return g; - } -}; - -} // namespace tvm -#endif // TVM_JSON_H_ diff --git a/include/tvm/common/base64.h b/src/common/base64.h similarity index 100% rename from include/tvm/common/base64.h rename to src/common/base64.h diff --git a/src/lang/reflection.cc b/src/lang/reflection.cc index d6100f3e33e9..bc3d2895b811 100644 --- a/src/lang/reflection.cc +++ b/src/lang/reflection.cc @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2019 by Contributors + * Copyright (c) 2016 by Contributors * \file reflection.cc * \brief Utilities to save/load/construct TVM objects */ @@ -29,11 +29,10 @@ #include #include #include -#include #include #include -#include #include +#include "../common/base64.h" namespace dmlc { DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg); @@ -45,10 +44,227 @@ ::dmlc::Registry* NodeFactoryReg::Registry() { return ::dmlc::Registry::Get(); } +inline std::string Type2String(const Type& t) { + return runtime::TVMType2String(Type2TVMType(t)); +} + + inline Type String2Type(std::string s) { return TVMType2Type(runtime::String2TVMType(s)); } +using runtime::Object; +using runtime::ObjectCell; + +// indexer to index all the ndoes +class NodeIndexer : public AttrVisitor { + public: + std::unordered_map node_index{{nullptr, 0}}; + std::vector node_list{nullptr}; + std::unordered_map tensor_index; + std::vector tensor_list; + std::unordered_map vm_obj_index; + std::vector vm_obj_list; + + void Visit(const char* key, double* value) final {} + void Visit(const char* key, int64_t* value) final {} + void Visit(const char* key, uint64_t* value) final {} + void Visit(const char* key, int* value) final {} + void Visit(const char* key, bool* value) final {} + void Visit(const char* key, std::string* value) final {} + void Visit(const char* key, void** value) final {} + void Visit(const char* key, Type* value) final {} + void Visit(const char* key, NodeRef* value) final { + MakeIndex(value->node_.get()); + } + + void Visit(const char* key, runtime::NDArray* value) final { + DLTensor* ptr = const_cast((*value).operator->()); + if (tensor_index.count(ptr)) return; + CHECK_EQ(tensor_index.size(), tensor_list.size()); + tensor_index[ptr] = tensor_list.size(); + tensor_list.push_back(ptr); + } + + void Visit(const char* key, Object* value) final { + ObjectCell* ptr = value->ptr_.get(); + if (vm_obj_index.count(ptr)) return; + CHECK_EQ(vm_obj_index.size(), vm_obj_list.size()); + vm_obj_index[ptr] = vm_obj_list.size(); + vm_obj_list.push_back(ptr); + } + + // make index of all the children of node + void MakeIndex(Node* node) { + if (node == nullptr) return; + if (node_index.count(node)) return; + CHECK_EQ(node_index.size(), node_list.size()); + node_index[node] = node_list.size(); + node_list.push_back(node); + + if (node->is_type()) { + ArrayNode* n = static_cast(node); + for (const auto& sp : n->data) { + MakeIndex(sp.get()); + } + } else if (node->is_type()) { + MapNode* n = static_cast(node); + for (const auto& kv : n->data) { + MakeIndex(kv.first.get()); + MakeIndex(kv.second.get()); + } + } else if (node->is_type()) { + StrMapNode* n = static_cast(node); + for (const auto& kv : n->data) { + MakeIndex(kv.second.get()); + } + } else { + node->VisitAttrs(this); + } + } +}; + +// use map so attributes are ordered. +using AttrMap = std::map; + +// A Node structure for JSON node. +struct JSONNode { + // The type key of the data + std::string type_key; + // The global key for global object + std::string global_key; + // the attributes + AttrMap attrs; + // container keys + std::vector keys; + // container data + std::vector data; + + void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("type_key", type_key); + if (global_key.size() != 0) { + writer->WriteObjectKeyValue("global_key", global_key); + } + if (attrs.size() != 0) { + writer->WriteObjectKeyValue("attrs", attrs); + } + if (keys.size() != 0) { + writer->WriteObjectKeyValue("keys", keys); + } + if (data.size() != 0) { + writer->WriteObjectKeyValue("data", data); + } + writer->EndObject(); + } + + void Load(dmlc::JSONReader *reader) { + attrs.clear(); + data.clear(); + global_key.clear(); + type_key.clear(); + dmlc::JSONObjectReadHelper helper; + helper.DeclareOptionalField("type_key", &type_key); + helper.DeclareOptionalField("global_key", &global_key); + helper.DeclareOptionalField("attrs", &attrs); + helper.DeclareOptionalField("keys", &keys); + helper.DeclareOptionalField("data", &data); + helper.ReadAllFields(reader); + } +}; + +class JSONAttrGetter : public AttrVisitor { + public: + const std::unordered_map* node_index_; + const std::unordered_map* tensor_index_; + const std::unordered_map* vm_obj_index_; + JSONNode* node_; + + void Visit(const char* key, double* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, int64_t* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, uint64_t* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, int* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, bool* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, std::string* value) final { + node_->attrs[key] = *value; + } + void Visit(const char* key, void** value) final { + LOG(FATAL) << "not allowed to serialize a pointer"; + } + void Visit(const char* key, Type* value) final { + node_->attrs[key] = Type2String(*value); + } + void Visit(const char* key, NodeRef* value) final { + node_->attrs[key] = std::to_string( + node_index_->at(value->node_.get())); + } + void Visit(const char* key, runtime::NDArray* value) final { + node_->attrs[key] = std::to_string( + tensor_index_->at(const_cast((*value).operator->()))); + } + void Visit(const char* key, Object* value) final { + node_->attrs[key] = std::to_string( + vm_obj_index_->at(value->ptr_.get())); + } + // Get the node + void Get(Node* node) { + if (node == nullptr) { + node_->type_key.clear(); + return; + } + node_->type_key = node->type_key(); + // sepcially handle global object + auto* f = dmlc::Registry::Find(node_->type_key); + CHECK(f != nullptr) + << "Node type \'" << node_->type_key << "\' is not registered in TVM"; + if (f->fglobal_key != nullptr) { + node_->global_key = f->fglobal_key(node); + return; + } + node_->attrs.clear(); + node_->data.clear(); + if (node->is_type()) { + ArrayNode* n = static_cast(node); + for (size_t i = 0; i < n->data.size(); ++i) { + node_->data.push_back( + node_index_->at(n->data[i].get())); + } + } else if (node->is_type()) { + MapNode* n = static_cast(node); + for (const auto& kv : n->data) { + node_->data.push_back( + node_index_->at(kv.first.get())); + node_->data.push_back( + node_index_->at(kv.second.get())); + } + } else if (node->is_type()) { + StrMapNode* n = static_cast(node); + for (const auto& kv : n->data) { + node_->keys.push_back(kv.first); + node_->data.push_back( + node_index_->at(kv.second.get())); + } + } else { + // do not need to recover content of global singleton object + // they are registered via the environment + auto* f = dmlc::Registry::Find(node->type_key()); + if (f != nullptr && f->fglobal_key != nullptr) return; + // recursively index normal object. + node->VisitAttrs(this); + } + } +}; + class JSONAttrSetter : public AttrVisitor { public: const std::vector >* node_list_; @@ -144,6 +360,66 @@ class JSONAttrSetter : public AttrVisitor { } }; +// json graph structure to store node +struct JSONGraph { + // the root of the graph + size_t root; + // the nodes of the graph + std::vector nodes; + // base64 b64ndarrays of arrays + std::vector b64ndarrays; + // global attributes + AttrMap attrs; + + void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("root", root); + writer->WriteObjectKeyValue("nodes", nodes); + writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays); + if (attrs.size() != 0) { + writer->WriteObjectKeyValue("attrs", attrs); + } + writer->EndObject(); + } + + void Load(dmlc::JSONReader *reader) { + attrs.clear(); + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("root", &root); + helper.DeclareField("nodes", &nodes); + helper.DeclareOptionalField("b64ndarrays", &b64ndarrays); + helper.DeclareOptionalField("attrs", &attrs); + helper.ReadAllFields(reader); + } + + static JSONGraph Create(const NodeRef& root) { + JSONGraph g; + NodeIndexer indexer; + indexer.MakeIndex(root.node_.get()); + JSONAttrGetter getter; + getter.node_index_ = &indexer.node_index; + getter.tensor_index_ = &indexer.tensor_index; + for (Node* n : indexer.node_list) { + JSONNode jnode; + getter.node_ = &jnode; + getter.Get(n); + g.nodes.emplace_back(std::move(jnode)); + } + g.attrs["tvm_version"] = TVM_VERSION; + g.root = indexer.node_index.at(root.node_.get()); + // serialize tensor + for (DLTensor* tensor : indexer.tensor_list) { + std::string blob; + dmlc::MemoryStringStream mstrm(&blob); + common::Base64OutStream b64strm(&mstrm); + runtime::SaveDLTensor(&b64strm, tensor); + b64strm.Finish(); + g.b64ndarrays.emplace_back(std::move(blob)); + } + return g; + } +}; + std::string SaveJSON(const NodeRef& n) { auto jgraph = JSONGraph::Create(n); std::ostringstream os; diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index a7283125ae31..6a10b60bc700 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -48,8 +48,7 @@ struct TextNode : DocAtomNode { explicit TextNode(const std::string& str) : str(str) { if (str.find_first_of("\t\n") != str.npos) { - LOG(FATAL) << "text node: '" << str << "' should not has tab or newline."; - throw; + LOG(WARNING) << "text node: '" << str << "' should not has tab or newline."; } } }; diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index e6510dec69b9..0ee76dc4c9aa 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -36,7 +36,6 @@ #include #include #include -#include #include "doc.h" #include "type_functor.h" #include "../pass/dependency_graph.h" @@ -137,75 +136,13 @@ class TextMetaDataContext { return Doc("\"") << str << "\": " << v; } - template - Doc PrintVector(const std::vector& vec) const { - std::vector docs; - for (const auto& t : vec) { - docs.push_back(Doc(t)); - } - return Doc("[") << PrintSep(docs) << "]"; - } - - Doc PrintAttrMap(const AttrMap& m) const { - std::vector docs; - for (const auto& p : m) { - docs.push_back(PrintKeyValue(p.first, PrintString(p.second))); - } - return Brace(PrintSep(docs, Doc(",") << PrintNewLine())); - } - - Doc PrintJSONNode(const JSONNode& n) const { - std::vector docs; - docs.push_back(PrintKeyValue("type_key", PrintString(n.type_key))); - if (!n.global_key.empty()) { - docs.push_back(PrintKeyValue("global_key", Doc(n.global_key))); - } - if (!n.attrs.empty()) { - docs.push_back(PrintKeyValue("attrs", PrintAttrMap(n.attrs))); - } - std::vector keys; - for (const auto& k : n.keys) { - keys.push_back(PrintString(k)); - } - if (!n.keys.empty()) { - docs.push_back(PrintKeyValue("keys", PrintVector(keys))); - } - if (!n.data.empty()) { - docs.push_back(PrintKeyValue("data", PrintVector(n.data))); - } - return Brace(PrintSep(docs, Doc(",") << PrintNewLine())); - } - - Doc PrintJSONGraph(const JSONGraph& j) const { - std::vector docs; - docs.push_back(PrintKeyValue("root", Doc(j.root))); - std::vector nodes; - for (const auto& node : j.nodes) { - nodes.push_back(PrintJSONNode(node)); - } - docs.push_back(PrintKeyValue("nodes", - Brace(PrintSep(nodes, Doc(",") << PrintNewLine()), - "[", - "]"))); - std::vector b64; - for (const auto& b : j.b64ndarrays) { - b64.push_back(PrintString(b)); - } - docs.push_back(PrintKeyValue("b64ndarrays", PrintVector(b64))); - if (!j.attrs.empty()) { - docs.push_back(PrintKeyValue("attrs", PrintAttrMap(j.attrs))); - } - return Brace(Doc(PrintSep(docs, Doc(",") << PrintNewLine()))); - } - /*! * \brief Get the metadata section in json format. * \return the meta data string. */ Doc GetMetaSection() const { if (meta_data_.size() == 0) return Doc(); - auto m = Map(meta_data_.begin(), meta_data_.end()); - return PrintJSONGraph(JSONGraph::Create(m)); + return Doc(SaveJSON(Map(meta_data_.begin(), meta_data_.end()))); } /*! \return whether the meta data context is empty. */