diff --git a/.gitignore b/.gitignore index 506e54d930671..66eb0cb4f8665 100644 --- a/.gitignore +++ b/.gitignore @@ -230,7 +230,3 @@ conda/pkg # nix files .envrc *.nix - -# antlr files -*.tokens -*.interp diff --git a/CMakeLists.txt b/CMakeLists.txt index 512c8f44e53b8..aa632c40e3992 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,7 +8,6 @@ include(cmake/util/FindOpenCL.cmake) include(cmake/util/FindVulkan.cmake) include(cmake/util/FindLLVM.cmake) include(cmake/util/FindROCM.cmake) -include(cmake/util/FindANTLR.cmake) if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake) include(${CMAKE_CURRENT_BINARY_DIR}/config.cmake) @@ -69,7 +68,6 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF) tvm_option(USE_RANDOM "Build with random support" OFF) tvm_option(USE_TENSORRT "Build with TensorRT, must have CUDA and CUDNN enabled" OFF) tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF) -tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) tvm_option(USE_CPP_RPC "Build CPP RPC" OFF) tvm_option(USE_TFLITE "Build with tflite support" OFF) tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none) @@ -324,7 +322,6 @@ include(cmake/modules/Metal.cmake) include(cmake/modules/ROCM.cmake) include(cmake/modules/LLVM.cmake) include(cmake/modules/Micro.cmake) -include(cmake/modules/ANTLR.cmake) include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/CODEGENC.cmake) include(cmake/modules/contrib/DNNL.cmake) diff --git a/cmake/modules/ANTLR.cmake b/cmake/modules/ANTLR.cmake deleted file mode 100644 index d3c1b42182537..0000000000000 --- a/cmake/modules/ANTLR.cmake +++ /dev/null @@ -1,40 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -if(USE_ANTLR) - find_antlr(${USE_ANTLR}) - if(ANTLR4) - - set(RELAY_PARSER_DIR - ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) - - set(RELAY_PARSER - ${RELAY_PARSER_DIR}/py3/RelayVisitor.py - ${RELAY_PARSER_DIR}/py3/RelayParser.py - ${RELAY_PARSER_DIR}/py3/RelayLexer.py) - - - # Generate ANTLR grammar for parsing. - add_custom_command(OUTPUT ${RELAY_PARSER} - COMMAND ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 - DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 - WORKING_DIRECTORY ${RELAY_PARSER_DIR}) - - add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) - else() - message(FATAL_ERROR "Can't find ANTLR4") - endif() -endif(USE_ANTLR) diff --git a/cmake/util/FindANTLR.cmake b/cmake/util/FindANTLR.cmake deleted file mode 100644 index 3e490187083eb..0000000000000 --- a/cmake/util/FindANTLR.cmake +++ /dev/null @@ -1,65 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -####################################################### -# Enhanced version of find ANTLR. -# -# Usage: -# find_antlr(${USE_ANTLR}) -# -# - When USE_ANTLR=ON, use auto search by first trying to find antlr4 program, -# then trying to find antlr-*-complete.jar -# - When USE_ANTLR=/path/to/antlr-*-complete.jar, use provided jar -# -# Provide variables: -# - ANTLR4 -# -macro(find_antlr use_antlr) - set(JAVA_HOME $ENV{JAVA_HOME}) - if (NOT DEFINED JAVA_HOME) - # Hack to get system to search for Java itself. - message(STATUS "JAVA_HOME is not defined. Set it to ensure proper use") - set(JAVA_HOME "/usr") - endif() - if(MSVC) - set(JAVA_PROGRAM ${JAVA_HOME}/java.exe) - else() - set(JAVA_PROGRAM ${JAVA_HOME}/bin/java) - endif() - message(STATUS "Using Java at " ${JAVA_PROGRAM}) - - if (${use_antlr} STREQUAL "ON") - find_program(ANTLR4 antlr4) - if (NOT ANTLR4) - file(GLOB_RECURSE ANTLR4JAR - /usr/local/lib/antlr-*-complete.jar - /usr/local/Cellar/*antlr-*-complete.jar) - - # Get the first element of the list of antlr jars. - # Sort and reverse the list so the item selected is the highest - # version in lib or else in Cellar if no lib installation exists. - list(SORT ANTLR4JAR) - list(REVERSE ANTLR4JAR) - list(GET ANTLR4JAR 0 ANTLR4JAR) - - set(ANTLR4 ${JAVA_PROGRAM} -jar ${ANTLR4JAR}) - endif() - elseif(NOT ${use_antlr} STREQUAL "OFF") - set(ANTLR4 ${JAVA_PROGRAM} -jar ${use_antlr}) - endif() - message(STATUS "ANTLR4=${ANTLR4}") -endmacro(find_antlr) diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 828fff4e5fc62..df416d48ce091 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -60,9 +60,6 @@ ENV PATH $PATH:$CARGO_HOME/bin:/usr/lib/go-1.10/bin COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh RUN bash /install/ubuntu_install_java.sh -COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh -RUN bash /install/ubuntu_install_antlr.sh - # Chisel deps for TSIM COPY install/ubuntu_install_chisel.sh /install/ubuntu_install_chisel.sh RUN bash /install/ubuntu_install_chisel.sh diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index 6233cc53211f7..7b8468ef73466 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -87,9 +87,6 @@ RUN bash /install/ubuntu_install_vulkan.sh COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh RUN bash /install/ubuntu_install_redis.sh -COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh -RUN bash /install/ubuntu_install_antlr.sh - # NNPACK deps COPY install/ubuntu_install_nnpack.sh /install/ubuntu_install_nnpack.sh RUN bash /install/ubuntu_install_nnpack.sh diff --git a/docker/Dockerfile.ci_wasm b/docker/Dockerfile.ci_wasm index 965bc01d22d89..85f942d57ca3b 100644 --- a/docker/Dockerfile.ci_wasm +++ b/docker/Dockerfile.ci_wasm @@ -33,9 +33,6 @@ RUN bash /install/ubuntu1804_install_llvm.sh COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh RUN bash /install/ubuntu_install_java.sh -COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh -RUN bash /install/ubuntu_install_antlr.sh - COPY install/ubuntu_install_nodejs.sh /install/ubuntu_install_nodejs.sh RUN bash /install/ubuntu_install_nodejs.sh diff --git a/docker/install/ubuntu_install_antlr.sh b/docker/install/ubuntu_install_antlr.sh deleted file mode 100755 index de713a6f6a327..0000000000000 --- a/docker/install/ubuntu_install_antlr.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -set -e -set -u -set -o pipefail - -cd /usr/local/lib -wget -q https://www.antlr.org/download/antlr-4.7.1-complete.jar -cd - diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index 2eaf00e8fdd03..2ad55c0e521ec 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -21,4 +21,4 @@ set -u set -o pipefail # install libraries for python package on ubuntu -pip3 install pylint==1.9.4 six numpy pytest cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs requests Pillow packaging +pip3 install pylint==1.9.4 six numpy pytest cython decorator scipy tornado typed_ast pytest mypy orderedset attrs requests Pillow packaging diff --git a/docs/README.txt b/docs/README.txt index 281cafaeee893..87acd306b6b2c 100644 --- a/docs/README.txt +++ b/docs/README.txt @@ -42,12 +42,12 @@ You can run the following script to reproduce the CI sphinx pre-check stage. This script skips the tutorial executions and is useful for quickly check the content. ```bash -./tests/scrpts/task_sphinx_precheck.sh +./tests/scripts/task_sphinx_precheck.sh ``` The following script runs the full build which includes tutorial executions. You will need a gpu CI environment. ```bash -./tests/scrpts/task_python_docs.sh +./tests/scripts/task_python_docs.sh ``` diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index 26aec77e09e22..9fe6c5e0eb23a 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -213,13 +213,6 @@ like ``virtualenv``. pip3 install --user tornado psutil xgboost - * If you want to build tvm to compile a model, you must use Python 3 and run the following - - .. code:: bash - - sudo apt install antlr4 - pip3 install --user mypy orderedset antlr4-python3-runtime - Install Contrib Libraries ------------------------- diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 749274acbb961..7981d58b0ead0 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -353,7 +353,8 @@ struct AttrInitEntry { ~AttrInitEntry() DMLC_THROW_EXCEPTION { if (value_missing_) { std::ostringstream os; - os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization"; + os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization." + << "If the key is defined check that its type matches the declared type."; throw AttrError(os.str()); } } diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 40f854b027c6e..95a1acb9412db 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -80,21 +80,29 @@ class Span; class SpanNode : public Object { public: /*! \brief The source name. */ - SourceName source; + SourceName source_name; /*! \brief The line number. */ int line; /*! \brief The column offset. */ int column; + /*! \brief The end line number. */ + int end_line; + /*! \brief The end column number. */ + int end_column; // override attr visitor void VisitAttrs(AttrVisitor* v) { - v->Visit("source", &source); + v->Visit("source_name", &source_name); v->Visit("line", &line); v->Visit("column", &column); + v->Visit("end_line", &end_line); + v->Visit("end_column", &end_column); } bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const { - return equal(source, other->source) && equal(line, other->line) && equal(column, other->column); + return equal(source_name, other->source_name) && equal(line, other->line) && + equal(column, other->column) && equal(end_line, other->end_line) && + equal(end_column, other->end_column); } static constexpr const char* _type_key = "Span"; @@ -103,7 +111,10 @@ class SpanNode : public Object { class Span : public ObjectRef { public: - TVM_DLL Span(SourceName source, int lineno, int col_offset); + TVM_DLL Span(SourceName source_name, int line, int end_line, int column, int end_column); + + /*! \brief Merge two spans into one which captures the combined regions. */ + TVM_DLL Span Merge(const Span& other); TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); }; diff --git a/include/tvm/parser/parser.h b/include/tvm/parser/parser.h index 93803588383a0..5c1239b1f59e9 100644 --- a/include/tvm/parser/parser.h +++ b/include/tvm/parser/parser.h @@ -32,7 +32,7 @@ namespace tvm { namespace parser { -IRModule Parse(std::string file_name, std::string file_content); +IRModule ParseModule(std::string file_name, std::string file_content); } // namespace parser } // namespace tvm diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h new file mode 100644 index 0000000000000..cf926665e2185 --- /dev/null +++ b/include/tvm/parser/source_map.h @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file source_map.h + * \brief A map from source names to source code. + */ +#ifndef TVM_PARSER_SOURCE_MAP_H_ +#define TVM_PARSER_SOURCE_MAP_H_ + +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace parser { + +/*! \brief A program source in any language. + * + * Could represent the source from an ML framework or the internal + * source of a TVM program. + */ +struct Source { + /*! \brief The source name. */ + SourceName source_name; + + /*! \brief The raw source. */ + std::string source; + /*! \brief A mapping of line breaks into the raw source. */ + std::vector> line_map; + + /*! \brief An empty source. */ + Source() : source_name(), source(), line_map() {} + + /*! \brief Construct a source from a string. */ + TVM_DLL explicit Source(const SourceName& src_name, const std::string& source); + + TVM_DLL Source(const Source& source) + : source_name(source.source_name), source(source.source), line_map(source.line_map) {} + + /*! \brief Generate an error message at a specific line and column with the + * annotated message. + * + * The error is written directly to the `out` std::ostream. + * + * \param out The output ostream. + * \param span The span to report the error at. + * \param msg The message to attach. + * + */ + // TODO(@jroesch): replace the ostream with an interface for rendering errors. + TVM_DLL void ReportAt(std::ostream& out, const Span& span, const std::string& msg) const; +}; + +/*! + * \brief A mapping from a unique source name to source fragment. + */ +class SourceMap; +/*! + * \brief Stores locations in frontend source that generated a node. + */ +class SourceMapNode : public Object { + public: + /*! \brief The source mapping. */ + Map source_map; + + // override attr visitor + void VisitAttrs(AttrVisitor* v) { v->Visit("source_map", &source_map); } + + bool SEqualReduce(const SourceMapNode* other, SEqualReducer equal) const { + return equal(source_map, other->source_map); + } + + static constexpr const char* _type_key = "SourceMap"; + TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapNode, Object); +}; + +class SourceMap : public ObjectRef { + public: + TVM_DLL SourceMap(Map source_map); + + TVM_DLL static SourceMap* Get(); + + TVM_DEFINE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapNode); +}; + +} // namespace parser +} // namespace tvm + +#endif // TVM_PARSER_SOURCE_MAP_H_ diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index b2164ba8c1f79..37182abb26813 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -309,8 +309,9 @@ class Match : public Expr { * \param data the input being deconstructed. * \param clauses The clauses for matching. * \param complete Indicate if this match is complete. + * \param span The span of the expression. */ - TVM_DLL Match(Expr data, tvm::Array clauses, bool complete = true); + TVM_DLL Match(Expr data, tvm::Array clauses, bool complete = true, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode); }; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 3c156dfd74812..d0c9217958f02 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -88,8 +88,9 @@ class Constant : public Expr { /*! * \brief The constructor * \param data The data of the constant tensor. + * \param span The source span of the expression. */ - TVM_DLL explicit Constant(runtime::NDArray data); + TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode); }; @@ -134,8 +135,9 @@ class Tuple : public Expr { /*! * \brief The constructor * \param fields The fields of a tuple. + * \param span The source span of the expression. */ - TVM_DLL explicit Tuple(tvm::Array fields); + TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode); }; @@ -188,10 +190,6 @@ class VarNode : public ExprNode { hash_reduce.FreeVarHashImpl(this); } - TVM_DLL static Var make(String name_hint, Type type_annotation); - - TVM_DLL static Var make(Id vid, Type type_annotation); - static constexpr const char* _type_key = "relay.Var"; TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode); }; @@ -202,15 +200,18 @@ class Var : public Expr { * \brief The constructor * \param name_hint The name hint of a variable. * \param type_annotation The type annotation of a variable. + * \param span The source span of the expression. */ - TVM_DLL Var(String name_hint, Type type_annotation) : Var(Id(name_hint), type_annotation) {} + TVM_DLL Var(String name_hint, Type type_annotation, Span span = Span()) + : Var(Id(name_hint), type_annotation, span) {} /*! * \brief The constructor * \param vid The unique id of a variable. * \param type_annotation The type annotation of a variable. + * \param span The source span of the expression. */ - TVM_DLL Var(Id vid, Type type_annotation); + TVM_DLL Var(Id vid, Type type_annotation, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode); }; @@ -295,9 +296,10 @@ class Call : public Expr { * \param args The arguments of the call. * \param attrs The attributes of the call node. * \param type_args The type arguments passed to a polymorphic function. + * \param span The source span of the expression. */ TVM_DLL Call(Expr op, Array args, Attrs attrs = Attrs(), - Array type_args = Array()); + Array type_args = Array(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode); }; @@ -356,8 +358,9 @@ class Let : public Expr { * \param var The variable that is bound to. * \param value The value used to bind to the variable. * \param body The body of the let binding. + * \param span The source span of the expression. */ - TVM_DLL Let(Var var, Expr value, Expr body); + TVM_DLL Let(Var var, Expr value, Expr body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode); }; @@ -416,8 +419,9 @@ class If : public Expr { * \param cond The condition of a if node. * \param true_branch The fall through branch * \param false_branch The branch for execution when condition is false. + * \param span The source span of the expression. */ - TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch); + TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode); }; @@ -457,8 +461,9 @@ class TupleGetItem : public Expr { * \brief The constructor * \param tuple The tuple to get an element from. * \param index The index for extracting a value in the tuple. + * \param span The source span of the expression. */ - TVM_DLL TupleGetItem(Expr tuple, int index); + TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode); }; @@ -495,8 +500,9 @@ class RefCreate : public Expr { /*! * \brief The constructor * \param value The initial value of the reference. + * \param span The source span of the expression. */ - TVM_DLL explicit RefCreate(Expr value); + TVM_DLL explicit RefCreate(Expr value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode); }; @@ -533,8 +539,9 @@ class RefRead : public Expr { /*! * \brief The constructor * \param ref The reference where to read data. + * \param span The source span of the expression. */ - TVM_DLL explicit RefRead(Expr ref); + TVM_DLL explicit RefRead(Expr ref, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode); }; @@ -565,8 +572,6 @@ class RefWriteNode : public ExprNode { hash_reduce(value); } - TVM_DLL static RefWrite make(Expr ref, Expr value); - static constexpr const char* _type_key = "relay.RefWrite"; TVM_DECLARE_FINAL_OBJECT_INFO(RefWriteNode, ExprNode); }; @@ -577,8 +582,9 @@ class RefWrite : public Expr { * \brief The constructor * \param ref The reference where data is write to. * \param value The value to write. + * \param span The source span of the expression. */ - TVM_DLL RefWrite(Expr ref, Expr value); + TVM_DLL RefWrite(Expr ref, Expr value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode); }; diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 1189643c81813..c3d2f724b7369 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -164,6 +164,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor { virtual void VisitType(const Type& t); virtual void VisitClause(const Clause& c); virtual void VisitPattern(const Pattern& c); + virtual void VisitSpan(const Span& span); protected: // Internal visiting counter diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index d52a66cdadeb3..db973b91f92aa 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -110,9 +110,10 @@ class Function : public BaseFunc { * \param ret_type The return type of the function. * \param ty_params The type parameters. * \param attrs Additional function attributes. + * \param span The span of the function. */ TVM_DLL Function(tvm::Array params, Expr body, Type ret_type, tvm::Array ty_params, - tvm::DictAttrs attrs = NullValue()); + tvm::DictAttrs attrs = NullValue(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode); diff --git a/python/setup.py b/python/setup.py index 682589ef5e6f3..3205a7cfa5253 100644 --- a/python/setup.py +++ b/python/setup.py @@ -167,8 +167,7 @@ def get_package_data_files(): 'psutil', 'xgboost>=1.1.0', 'mypy', - 'orderedset', - 'antlr4-python3-runtime']}, + 'orderedset']}, packages=find_packages(), package_dir={'tvm': 'tvm'}, diff --git a/python/tvm/error.py b/python/tvm/error.py index b3502f6b0eada..9125448e30c98 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -121,3 +121,10 @@ class OpAttributeUnImplemented(OpError, NotImplementedError): "Attribute {} is not supported in operator {}".format( attr_name, op_name)) """ + +@register_error +class DiagnosticError(TVMError): + """Error diagnostics were reported during the execution of a pass. + + See the configured diagnostic renderer for detailed error information. + """ diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index bab98382e713e..b505a2ee00bb9 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -84,9 +84,9 @@ class Span(Object): col_offset : int The column offset of the location. """ - def __init__(self, source, lineno, col_offset): + def __init__(self, source_name, line, end_line, column, end_column): self.__init_handle_by_constructor__( - _ffi_api.Span, source, lineno, col_offset) + _ffi_api.Span, source_name, line, end_line, column, end_column) @tvm._ffi.register_object diff --git a/python/tvm/parser/__init__.py b/python/tvm/parser/__init__.py index 071c464dae516..8001cd4167812 100644 --- a/python/tvm/parser/__init__.py +++ b/python/tvm/parser/__init__.py @@ -24,4 +24,4 @@ def parse_expr(source): return _ffi_api.ParseExpr("string", source) def fromtext(source, source_name="from_string"): - return parse(str(source), str(source_name)) + return parse(source, source_name) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index e3909d9d63781..cd96ecc7ee33f 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -29,7 +29,6 @@ from . import prelude from . import loops from . import scope_builder -from . import parser from . import transform from . import analysis @@ -132,12 +131,9 @@ # Prelude Prelude = prelude.Prelude -# Scope builder +# Scope Builder ScopeBuilder = scope_builder.ScopeBuilder -# Parser -fromtext = parser.fromtext - # Param Serialization save_param_dict = param_dict.save_param_dict load_param_dict = param_dict.load_param_dict diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py deleted file mode 100644 index 0d3f86f6262df..0000000000000 --- a/python/tvm/relay/_parser.py +++ /dev/null @@ -1,771 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=invalid-name, unused-argument -"""A parser for Relay's text format.""" -from __future__ import absolute_import - -import sys -from ast import literal_eval -from collections import deque - -try: - # no typing.Deque in Python 3.5 - # https://bugs.python.org/issue29011 - from typing import Any, Dict, List, Optional, TypeVar, Tuple, Union, MutableSequence, T, Deque -except ImportError: - class Deque(deque, MutableSequence[T], extra=deque): - - def __new__(cls, *args, **kwds): - if _geqv(cls, Deque): - raise TypeError("Type Deque cannot be instantiated; " - "use deque() instead") - return deque.__new__(cls, *args, **kwds) - -import tvm -import tvm.ir._ffi_api -from tvm.ir import IRModule - -from .base import Span, SourceName -from . import adt -from . import expr -from . import function -from . import ty -from . import op - -PYTHON_VERSION = sys.version_info.major -try: - from antlr4 import InputStream, CommonTokenStream - from antlr4.error.ErrorListener import ErrorListener -except ImportError: - raise Exception("Couldn't find ANTLR runtime." + - "Try running `pip{version} install antlr4-python{version}-runtime`." - .format(version=PYTHON_VERSION)) - -try: - from .grammar.py3.RelayVisitor import RelayVisitor - from .grammar.py3.RelayParser import RelayParser - from .grammar.py3.RelayLexer import RelayLexer -except ImportError: - raise Exception("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.") - - -sys.setrecursionlimit(10000) - -class ParseError(Exception): - """Exception type for parse errors.""" - - def __init__(self, message: str) -> None: - super(ParseError, self).__init__() - self.message = message - - def __repr__(self): - return "ParseError({})".format(self.message) - - def __str__(self): - return repr(self) - -class OpWrapper: - """Overload the __call__ for op.""" - - -class ExprOp(OpWrapper): - """Call an expr. The default, but does not handle attrs well.""" - def __init__(self, operator): - self.operator = operator - - def __call__(self, args, attrs, type_args): - try: - return expr.Call(self.operator, args, attrs, type_args) - except Exception: - raise Exception("Operator {} is not registered. It's attributes are {}" - .format(self.operator, attrs)) - -class FuncOp(OpWrapper): - """Convert the attrs, call the python function with the attrs passed in as keyword arguments. - Tvm should provide this in the future, as this is pretty similar to what op.get is providing. - """ - def __init__(self, operator): - self.operator = operator - - def convert(self, v): - if isinstance(v, tuple): - return tuple([self.convert(x) for x in v]) - if isinstance(v, expr.Constant): - return v.data.asnumpy().item() - if isinstance(v, str): - return v - raise Exception(v) - - def __call__(self, args, attrs, type_args): - if attrs is None: - attrs = {} - if self.operator in (op.strided_slice,): - x = self.operator(*args) - else: - x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()}) - if isinstance(x, expr.TupleWrapper): - x = x.astuple() - return x - -BINARY_OPS = { - RelayParser.MUL: op.multiply, - RelayParser.DIV: op.divide, - RelayParser.ADD: op.add, - RelayParser.SUB: op.subtract, - RelayParser.LT: op.less, - RelayParser.GT: op.greater, - RelayParser.LE: op.less_equal, - RelayParser.GE: op.greater_equal, - RelayParser.EQ: op.equal, - RelayParser.NE: op.not_equal, -} - -FUNC_OPS = { - "nn.conv2d": op.nn.conv2d, - "nn.batch_norm": op.nn.batch_norm, - "nn.dense": op.nn.dense, - "nn.bias_add": op.nn.bias_add, - "nn.max_pool2d": op.nn.max_pool2d, - "nn.max_pool3d": op.nn.max_pool3d, - "nn.global_max_pool2d": op.nn.global_max_pool2d, - "nn.avg_pool2d": op.nn.avg_pool2d, - "nn.avg_pool3d": op.nn.avg_pool3d, - "nn.global_avg_pool2d": op.nn.global_avg_pool2d, - "nn.softmax": op.nn.softmax, - "reshape": op.reshape, - "nn.conv2d_transpose": op.nn.conv2d_transpose, - "nn.conv1d_transpose": op.nn.conv1d_transpose, - "concatenate": op.concatenate, - "nn.dropout": op.nn.dropout_raw, - "zeros": op.zeros, - "split": op.split, - "cast": op.cast, - "clip": op.clip, - "right_shift": op.right_shift, -} - -TYPE_PREFIXES = [ - "int", - "uint", - "float", - "bool", -] - -T = TypeVar("T") -Scope = Deque[Tuple[str, T]] -Scopes = Deque[Scope[T]] - -def lookup(scopes: Scopes[T], name: str) -> Optional[T]: - """Look up `name` in `scopes`.""" - - for scope in scopes: - for key, val in scope: - if key == name: - return val - return None - -def spanify(f): - """A decorator which attaches span information - to the value returned by calling `f`. - - Intended for use with the below AST visiting - methods. The idea is that after we do the work - of constructing the AST we attach Span information. - """ - - def _wrapper(*args, **kwargs): - # Assumes 0th arg is self and gets source_name from object. - sn = args[0].source_name - # Assumes 1st arg is an ANTLR parser context. - ctx = args[1] - ast = f(*args, **kwargs) - line, col = ctx.getSourceInterval() - sp = Span(sn, line, col) - if isinstance(ast, tvm.relay.expr.TupleWrapper): - ast = ast.astuple() - tvm.ir._ffi_api.NodeSetSpan(ast, sp) - return ast - return _wrapper - -# TODO(@jmp): Use https://stackoverflow.com/q/13889941 -# to figure out how to get ANTLR4 to be more unhappy about syntax errors -class ParseTreeToRelayIR(RelayVisitor): - """Parse Relay text format into Relay IR.""" - - def __init__(self, source_name: str) -> None: - self.source_name = source_name - self.module = IRModule({}) # type: IRModule - - # Adding an empty scope allows naked lets without pain. - self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] - self.global_vars = {} # type: Scope[expr.GlobalVar] - self.type_var_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] - self.global_type_vars = {} # type: Scope[expr.GlobalVar] - self.graph_expr = [] # type: List[expr.Expr] - - super(ParseTreeToRelayIR, self).__init__() - - - def enter_var_scope(self) -> None: - """Enter a new Var scope so it can be popped off later.""" - self.var_scopes.appendleft(deque()) - - def exit_var_scope(self) -> Scope[expr.Var]: - """Pop off the current Var scope and return it.""" - return self.var_scopes.popleft() - - def mk_var(self, name: str, typ: ty.Type = None): - """Create a new Var and add it to the Var scope.""" - var = expr.Var(name, typ) - self.var_scopes[0].appendleft((name, var)) - return var - - def mk_global_var(self, name: str) -> expr.GlobalVar: - """Create a new GlobalVar and add it to the GlobalVar scope.""" - if name in self.global_vars: - raise ParseError("duplicate global var \"{0}\"".format(name)) - var = expr.GlobalVar(name) - self.global_vars[name] = var - return var - - def enter_type_param_scope(self) -> None: - """Enter a new TypeVar scope so it can be popped off later.""" - self.type_var_scopes.appendleft(deque()) - - def exit_type_param_scope(self) -> Scope[ty.TypeVar]: - """Pop off the current TypeVar scope and return it.""" - return self.type_var_scopes.popleft() - - def mk_typ(self, name: str, kind: ty.TypeKind) -> ty.TypeVar: - """Create a new TypeVar and add it to the TypeVar scope.""" - typ = ty.TypeVar(name, kind) - self.type_var_scopes[0].append((name, typ)) - return typ - - def mk_global_typ_var(self, name, kind): - # (str, ty.Kind) -> ty.GlobalTypeVar - """Create a new TypeVar and add it to the TypeVar scope.""" - typ = ty.GlobalTypeVar(name, kind) - self._check_existing_typ_expr(name, typ) - self.global_type_vars[name] = typ - return typ - - # TODO(weberlo): rethink whether we should have type constructors mixed with type vars. - def mk_global_typ_cons(self, name, cons): - self._check_existing_typ_expr(name, cons) - self.global_type_vars[name] = cons - - def _check_existing_typ_expr(self, name, new_expr): - if name in self.global_type_vars: - new_typ_name = self._type_expr_name(new_expr) - existing_typ_name = self._type_expr_name(self.global_type_vars[name]) - raise ParseError( - "{0} `{1}` conflicts with existing {2}".format(new_typ_name,\ - name, existing_typ_name)) - - def _type_expr_name(self, e): - if isinstance(e, adt.Constructor): - return "`{0}` ADT constructor".format(e.belong_to.name_hint) - if isinstance(e, ty.GlobalTypeVar): - if e.kind == ty.TypeKind.AdtHandle: - return "ADT definition" - return "function definition" - - def visitProjection(self, ctx): - return expr.TupleGetItem(self.visit(ctx.expr()), self.visit(ctx.NAT())) - - def visitTerminal(self, node) -> Union[expr.Expr, int, float]: - """Visit lexer tokens that aren't ignored or visited by other functions.""" - node_type = node.getSymbol().type - node_text = node.getText() - - if node_type == RelayLexer.NAT: - return int(node_text) - if node_type == RelayLexer.FLOAT: - return float(node_text[:-1]) - if node_type == RelayLexer.BOOL_LIT: - if node_text == "True": - return True - if node_text == "False": - return False - raise ParseError("unrecognized BOOL_LIT: `{}`".format(node_text)) - if node_type == RelayLexer.QUOTED_STRING: - return literal_eval(node_text) - raise ParseError("unhandled terminal \"{0}\" of type `{1}`".format(node_text, node_type)) - - def visitGeneralIdent(self, ctx): - name = ctx.getText() - # Look through all type prefixes for a match. - for type_prefix in TYPE_PREFIXES: - if name.startswith(type_prefix): - return ty.scalar_type(name) - # Next, look it up in the local then global type params. - type_expr = lookup(self.type_var_scopes, name) - if type_expr is None: - type_expr = self.global_type_vars.get(name, None) - if type_expr is not None: - # Zero-arity constructor calls fall into the general ident case, so in that case, - # we construct a constructor call with no args. - if isinstance(type_expr, adt.Constructor) and not type_expr.inputs: - type_expr = expr.Call(type_expr, []) - return type_expr - # Check if it's an operator. - op_name = ".".join([name.getText() for name in ctx.CNAME()]) - if op_name in FUNC_OPS: - return FuncOp(FUNC_OPS[op_name]) - return ExprOp(op.get(op_name)) - - def visitGlobalVar(self, ctx): - var_name = ctx.CNAME().getText() - global_var = self.global_vars.get(var_name, None) - if global_var is None: - raise ParseError("unbound global var `{0}`".format(var_name)) - return global_var - - def visitLocalVar(self, ctx): - var_name = ctx.CNAME().getText() - local_var = lookup(self.var_scopes, var_name) - if local_var is None: - raise ParseError("unbound local var `{0}`".format(var_name)) - return local_var - - def visitGraphVar(self, ctx): - return self.graph_expr[int(ctx.NAT().getText())] - - def visit_list(self, ctx_list) -> List[Any]: - """"Visit a list of contexts.""" - assert isinstance(ctx_list, list) - - return [self.visit(ctx) for ctx in ctx_list] - - def getTypeExpr(self, ctx: Optional[RelayParser.TypeExprContext]) -> Optional[ty.Type]: - """Return a (possibly None) Relay type.""" - if ctx is None: - return None - - return self.visit(ctx) - - def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, IRModule]: - self.meta = None - if ctx.METADATA(): - header, data = str(ctx.METADATA()).split("\n", 1) - assert header == "METADATA:" - self.meta = tvm.ir.load_json(data) - if ctx.defn(): - self.visit_list(ctx.defn()) - return self.module - - if ctx.expr(): - return self.visit(ctx.expr()) - - return self.module - - # Exprs - def visitOpIdent(self, ctx) -> tvm.ir.Op: - op_name = ".".join([name.getText() for name in ctx.CNAME()]) - if op_name in FUNC_OPS: - return FuncOp(FUNC_OPS[op_name]) - return ExprOp(op.get(op_name)) - - # pass through - def visitParen(self, ctx: RelayParser.ParenContext) -> expr.Expr: - return self.visit(ctx.expr()) - - # pass through - def visitTypeParen(self, ctx: RelayParser.TypeParenContext) -> expr.Expr: - return self.visit(ctx.typeExpr()) - - # pass through - def visitBody(self, ctx: RelayParser.BodyContext) -> expr.Expr: - return self.visit(ctx.expr()) - - def visitScalarFloat(self, ctx: RelayParser.ScalarFloatContext) -> expr.Constant: - return expr.const(self.visit(ctx.FLOAT())) - - def visitScalarInt(self, ctx: RelayParser.ScalarIntContext) -> expr.Constant: - return expr.const(self.visit(ctx.NAT())) - - def visitScalarBool(self, ctx: RelayParser.ScalarBoolContext) -> expr.Constant: - return expr.const(self.visit(ctx.BOOL_LIT())) - - def visitNeg(self, ctx: RelayParser.NegContext) -> Union[expr.Constant, expr.Call]: - val = self.visit(ctx.expr()) - if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0: - # fold Neg in for scalars - return expr.const(-val.data.asnumpy().item()) - - return op.negative(val) - - def visitTuple(self, ctx: RelayParser.TupleContext) -> expr.Tuple: - tup = self.visit_list(ctx.expr()) - return expr.Tuple(tup) - - def visitLet(self, ctx: RelayParser.LetContext) -> expr.Let: - """Desugar various sequence constructs to Relay Let nodes.""" - - if ctx.var() is None: - # anonymous identity - ident = "_" - typ = None - var = self.mk_var(ident, typ) - else: - var = self.visitVar(ctx.var()) - - self.enter_var_scope() - value = self.visit(ctx.expr(0)) - self.exit_var_scope() - - body = self.visit(ctx.expr(1)) - - return expr.Let(var, value, body) - - def visitBinOp(self, ctx: RelayParser.BinOpContext) -> expr.Call: - """Desugar binary operators.""" - arg0, arg1 = self.visit_list(ctx.expr()) - relay_op = BINARY_OPS.get(ctx.op.type) - - if relay_op is None: - raise ParseError("unimplemented binary op.") - - return relay_op(arg0, arg1) - - @spanify - def visitVar(self, ctx: RelayParser.VarContext) -> expr.Var: - """Visit a single variable.""" - ident = ctx.localVar() - - if ident is None: - raise ParseError("only local ids may be used in vars.") - - typeExpr = self.getTypeExpr(ctx.typeExpr()) - - return self.mk_var(ident.getText()[1:], typeExpr) - - def visitVarList(self, ctx: RelayParser.VarListContext) -> List[expr.Var]: - return self.visit_list(ctx.var()) - - # TODO: support a larger class of values than just Relay exprs - def visitAttr(self, ctx: RelayParser.AttrContext) -> Tuple[str, expr.Expr]: - return (ctx.CNAME().getText(), self.visit(ctx.expr())) - - def visitArgNoAttr(self, ctx: RelayParser.ArgNoAttrContext): - return (self.visit_list(ctx.varList().var()), None) - - def visitAttrSeq(self, ctx: RelayParser.AttrSeqContext) -> Dict[str, expr.Expr]: - return dict(self.visit_list(ctx.attr())) - - def visitArgWithAttr(self, ctx: RelayParser.AttrSeqContext) \ - -> Tuple[List[expr.Var], Dict[str, expr.Expr]]: - return (self.visit_list(ctx.var()), self.visitAttrSeq(ctx.attrSeq())) - - def visitArgList(self, ctx: RelayParser.ArgListContext) \ - -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]: - var_list = self.visit(ctx.varList()) if ctx.varList() else None - attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None - return (var_list, attr_list) - - def visitMeta(self, ctx: RelayParser.MetaContext): - type_key = str(ctx.CNAME()) - index = int(self.visit(ctx.NAT())) - return self.meta[type_key][index] - - def mk_func( - self, - ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \ - -> function.Function: - """Construct a function from either a Func or Defn.""" - # Enter var scope early to put params in scope. - self.enter_var_scope() - # Capture type params in params. - self.enter_type_param_scope() - type_params = ctx.typeParamList() - - if type_params is not None: - type_params = type_params.typeExpr() - assert type_params - for ty_param in type_params: - name = ty_param.getText() - self.mk_typ(name, ty.TypeKind.Type) - - var_list, attr_list = self.visit(ctx.argList()) - if var_list is None: - var_list = [] - ret_type = self.getTypeExpr(ctx.typeExpr()) - - body = self.visit(ctx.body()) - # NB(@jroesch): you must stay in the type parameter scope until - # after you exit the body, you can reference the type parameters - # of your parent scopes. - type_params = list(self.exit_type_param_scope()) - if type_params: - _, type_params = zip(*type_params) - self.exit_var_scope() - - attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None - return function.Function(var_list, body, ret_type, type_params, attrs) - - @spanify - def visitFunc(self, ctx: RelayParser.FuncContext) -> function.Function: - return self.mk_func(ctx) - - # TODO: how to set spans for definitions? - # @spanify - def visitFuncDefn(self, ctx: RelayParser.DefnContext) -> None: - ident_name = ctx.globalVar().getText()[1:] - ident = self.mk_global_var(ident_name) - func = self.mk_func(ctx) - self.module[ident] = func - - def handle_adt_header( - self, - ctx: Union[RelayParser.ExternAdtDefnContext, RelayParser.AdtDefnContext]): - """Handles parsing of the name and type params of an ADT definition.""" - adt_name = ctx.generalIdent().getText() - adt_var = self.mk_global_typ_var(adt_name, ty.TypeKind.AdtHandle) - # parse type params - type_params = ctx.typeParamList() - if type_params is None: - type_params = [] - else: - type_params = [self.mk_typ(type_ident.getText(), ty.TypeKind.Type) - for type_ident in type_params.typeExpr()] - return adt_var, type_params - - def visitExternAdtDefn(self, ctx: RelayParser.ExternAdtDefnContext): - # TODO(weberlo): update this handler once extern is implemented - self.enter_type_param_scope() - adt_var, type_params = self.handle_adt_header(ctx) - # update module being built - self.module[adt_var] = adt.TypeData(adt_var, type_params, []) - self.exit_type_param_scope() - - def visitAdtDefn(self, ctx: RelayParser.AdtDefnContext): - self.enter_type_param_scope() - adt_var, type_params = self.handle_adt_header(ctx) - # parse constructors - adt_cons_defns = ctx.adtConsDefnList() - if adt_cons_defns is None: - adt_cons_defns = [] - else: - adt_cons_defns = adt_cons_defns.adtConsDefn() - parsed_constructors = [] - for cons_defn in adt_cons_defns: - inputs = [self.visit(inp) for inp in cons_defn.typeExpr()] - cons_defn_name = cons_defn.constructorName().getText() - cons_defn = adt.Constructor(cons_defn_name, inputs, adt_var) - self.mk_global_typ_cons(cons_defn_name, cons_defn) - parsed_constructors.append(cons_defn) - # update module being built - self.module[adt_var] = adt.TypeData(adt_var, type_params, parsed_constructors) - self.exit_type_param_scope() - - def visitMatch(self, ctx: RelayParser.MatchContext): - match_type = ctx.matchType().getText() - if match_type == "match": - complete_match = True - elif match_type == "match?": - complete_match = False - else: - raise RuntimeError("unknown match type {0}".format(match_type)) - - match_data = self.visit(ctx.expr()) - match_clauses = ctx.matchClauseList() - if match_clauses is None: - match_clauses = [] - else: - match_clauses = match_clauses.matchClause() - parsed_clauses = [] - for clause in match_clauses: - self.enter_var_scope() - pattern = self.visit(clause.pattern()) - clause_body = self.visit(clause.expr()) - self.exit_var_scope() - parsed_clauses.append(adt.Clause(pattern, clause_body)) - return adt.Match(match_data, parsed_clauses, complete=complete_match) - - def visitWildcardPattern(self, ctx: RelayParser.WildcardPatternContext): - return adt.PatternWildcard() - - def visitVarPattern(self, ctx: RelayParser.VarPatternContext): - text = ctx.localVar().getText() - typ = ctx.typeExpr() - if typ is not None: - typ = self.visit(typ) - var = self.mk_var(text[1:], typ=typ) - return adt.PatternVar(var) - - def visitConstructorPattern(self, ctx: RelayParser.ConstructorPatternContext): - constructor_name = ctx.constructorName().getText() - constructor = self.global_type_vars[constructor_name] - pattern_list = ctx.patternList() - if pattern_list is None: - patterns = [] - else: - patterns = [self.visit(pattern) for pattern in pattern_list.pattern()] - return adt.PatternConstructor(constructor, patterns) - - def visitTuplePattern(self, ctx: RelayParser.TuplePatternContext): - return adt.PatternTuple([self.visit(pattern) for pattern in ctx.patternList().pattern()]) - - def visitCallNoAttr(self, ctx: RelayParser.CallNoAttrContext): - return (self.visit_list(ctx.exprList().expr()), None) - - def visitCallWithAttr(self, ctx: RelayParser.CallWithAttrContext): - return (self.visit_list(ctx.expr()), self.visit(ctx.attrSeq())) - - def call(self, func, args, attrs, type_args): - if isinstance(func, OpWrapper): - return func(args, attrs, type_args) - if isinstance(func, adt.Constructor): - return func(*args) - return expr.Call(func, args, attrs, type_args) - - @spanify - def visitCall(self, ctx: RelayParser.CallContext) -> expr.Call: - func = self.visit(ctx.expr()) - args, attrs = self.visit(ctx.callList()) - res = self.call(func, args, attrs, []) - return res - - @spanify - def visitIfElse(self, ctx: RelayParser.IfElseContext) -> expr.If: - """Construct a Relay If node. Creates a new scope for each branch.""" - cond = self.visit(ctx.expr()) - - self.enter_var_scope() - true_branch = self.visit(ctx.body(0)) - self.exit_var_scope() - - self.enter_var_scope() - false_branch = self.visit(ctx.body(1)) - self.exit_var_scope() - - return expr.If(cond, true_branch, false_branch) - - @spanify - def visitGraph(self, ctx: RelayParser.GraphContext) -> expr.Expr: - """Visit a graph variable assignment.""" - graph_nid = int(ctx.graphVar().getText()[1:]) - - self.enter_var_scope() - value = self.visit(ctx.expr(0)) - self.exit_var_scope() - - if graph_nid != len(self.graph_expr): - raise ParseError( - "expected new graph variable to be `%{}`,".format(len(self.graph_expr)) + \ - "but got `%{}`".format(graph_nid)) - self.graph_expr.append(value) - - kont = self.visit(ctx.expr(1)) - return kont - - # Types - - # pylint: disable=unused-argument - def visitIncompleteType(self, ctx: RelayParser.IncompleteTypeContext) -> None: - return None - - def visitTypeCallType(self, ctx: RelayParser.TypeCallTypeContext): - func = self.visit(ctx.generalIdent()) - args = [self.visit(arg) for arg in ctx.typeParamList().typeExpr()] - return ty.TypeCall(func, args) - - def visitParensShape(self, ctx: RelayParser.ParensShapeContext) -> int: - return self.visit(ctx.shape()) - - def visitShapeList(self, ctx: RelayParser.ShapeListContext) -> List[int]: - return self.visit_list(ctx.shape()) - - def visitTensor(self, ctx: RelayParser.TensorContext): - return tuple(self.visit_list(ctx.expr())) - - def visitTensorType(self, ctx: RelayParser.TensorTypeContext) -> ty.TensorType: - """Create a simple tensor type. No generics.""" - - shape = self.visit(ctx.shapeList()) - dtype = self.visit(ctx.typeExpr()) - - if not isinstance(dtype, ty.TensorType): - raise ParseError("expected dtype to be a Relay base type.") - - dtype = dtype.dtype - - return ty.TensorType(shape, dtype) - - def visitTupleType(self, ctx: RelayParser.TupleTypeContext) -> ty.TupleType: - return ty.TupleType(self.visit_list(ctx.typeExpr())) - - def visitFuncType(self, ctx: RelayParser.FuncTypeContext) -> ty.FuncType: - types = self.visit_list(ctx.typeExpr()) - - arg_types = types[:-1] - ret_type = types[-1] - - return ty.FuncType(arg_types, ret_type, [], None) - -def make_parser(data: str) -> RelayParser: - """Construct a RelayParser a given data stream.""" - input_stream = InputStream(data) - lexer = RelayLexer(input_stream) - lexer.addErrorListener(StrictErrorListener(data)) - token_stream = CommonTokenStream(lexer) - p = RelayParser(token_stream) - p.addErrorListener(StrictErrorListener(data)) - return p - -__source_name_counter__ = 0 - -class StrictErrorListener(ErrorListener): - """This ErrorListener fail eagerly on all error, and report the program.""" - def __init__(self, text): - self.text = text - - def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e): - raise Exception("Syntax Error in:\n" + self.text) - - def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs): - raise Exception("Ambiguity Error in:\n" + self.text) - - def reportAttemptingFullContext(self, - recognizer, - dfa, - startIndex, - stopIndex, - conflictingAlts, - configs): - raise Exception("Attempting Full Context in:\n" + self.text) - - def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs): - raise Exception("Context Sensitivity in:\n" + self.text) - -def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, IRModule]: - """Parse a Relay program.""" - if data == "": - raise ParseError("cannot parse the empty string.") - - global __source_name_counter__ - - if source_name is None: - source_name = "source_file{0}".format(__source_name_counter__) - - if isinstance(source_name, str): - source_name = SourceName(source_name) - - tree = make_parser(data).prog() - return ParseTreeToRelayIR(source_name).visit(tree) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index fbb98fcf9e3c4..106edc25c5ee3 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -518,7 +518,7 @@ def bind(expr, binds): expr : tvm.relay.Expr The input expression. - binds : Union[Map[tvm.relay.Var, tvm.relay.Expr], Map[str, tvm.relay.Expr]] + binds : Map[tvm.relay.Var, tvm.relay.Expr] The specific bindings. Returns diff --git a/python/tvm/relay/grammar/.gitignore b/python/tvm/relay/grammar/.gitignore deleted file mode 100644 index cffe35e1a41a8..0000000000000 --- a/python/tvm/relay/grammar/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/.antlr/ diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 deleted file mode 100644 index bfcd18ffc98f1..0000000000000 --- a/python/tvm/relay/grammar/Relay.g4 +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * NOTE: The `USE_ANTLR` option in `config.cmake` must be enabled in order for - * changes in this file to be reflected by the parser. - * NOTE: All upper-case rules are *lexer* rules and all camel-case rules are *parser* rules. - */ - -grammar Relay; - -SEMVER: 'v0.0.4' ; - -// Lexing -// comments -COMMENT : '/*' (COMMENT|.)*? '*/' -> skip; -WS : [ \t\n\r]+ -> skip; -LINE_COMMENT : '//' .*? '\n' -> skip; - -fragment ESCAPED_QUOTE : '\\"'; -QUOTED_STRING : '"' ( ESCAPED_QUOTE | ~('\n'|'\r') )*? '"'; - -// operators -MUL: '*' ; -DIV: '/' ; -ADD: '+' ; -SUB: '-' ; -LT: '<' ; -GT: '>' ; -LE: '<=' ; -GE: '>=' ; -EQ: '==' ; -NE: '!=' ; - -BOOL_LIT - : 'True' - | 'False' - ; - -CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)* ; - -// non-negative floats -fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4 - -FLOAT : PREFLOAT 'f'; - -// non-negative ints -NAT: DIGIT+ ; -fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...] - -fragment LETTER: [a-zA-Z]; -fragment DIGIT: [0-9]; - -METADATA: 'METADATA:' .*; -// Parsing - -// A Relay program is a list of global definitions or an expression. -prog: SEMVER (defn* | expr) METADATA? EOF ; - -// Covers both operator and type idents -generalIdent: CNAME ('.' CNAME)*; -globalVar: '@' CNAME ; -localVar: '%' ('_' | CNAME) ; -graphVar: '%' NAT ; - -exprList: (expr (',' expr)*)?; -callList - : exprList # callNoAttr - | (expr ',')* attrSeq # callWithAttr - ; - -expr - // operators - : '(' expr ')' # paren - // function application - | expr '(' callList ')' # call - | '-' expr # neg - | expr op=('*'|'/') expr # binOp - | expr op=('+'|'-') expr # binOp - | expr op=('<'|'>'|'<='|'>=') expr # binOp - | expr op=('=='|'!=') expr # binOp - // function definition - | func # funcExpr - // tuples and tensors - | '(' ')' # tuple - | '(' expr ',' ')' # tuple - | '(' expr (',' expr)+ ')' # tuple - | '[' (expr (',' expr)*)? ']' # tensor - | 'if' '(' expr ')' body 'else' body # ifElse - | matchType expr '{' matchClauseList? '}' # match - | expr '.' NAT # projection - // sequencing - | 'let' var '=' expr ';' expr # let - // sugar for let %_ = expr; expr - | expr ';;' expr # let - | graphVar '=' expr ';' expr # graph - | ident # identExpr - | scalar # scalarExpr - | meta # metaExpr - | QUOTED_STRING # stringExpr - ; - -func: 'fn' typeParamList? '(' argList ')' ('->' typeExpr)? body ; -defn - : 'def' globalVar typeParamList? '(' argList ')' ('->' typeExpr)? body # funcDefn - | 'extern' 'type' generalIdent typeParamList? # externAdtDefn - | 'type' generalIdent typeParamList? '{' adtConsDefnList? '}' # adtDefn - ; - -constructorName: CNAME ; - -adtConsDefnList: adtConsDefn (',' adtConsDefn)* ','? ; -adtConsDefn: constructorName ('(' typeExpr (',' typeExpr)* ')')? ; -matchClauseList: matchClause (',' matchClause)* ','? ; -matchClause: pattern '=>' ('{' expr '}' | expr) ; -// complete or incomplete match, respectively -matchType : 'match' | 'match?' ; - -patternList: '(' pattern (',' pattern)* ')'; -pattern - : '_' # wildcardPattern - | localVar (':' typeExpr)? # varPattern - | constructorName patternList? # constructorPattern - | patternList # tuplePattern - ; - -adtCons: constructorName adtConsParamList? ; -adtConsParamList: '(' adtConsParam (',' adtConsParam)* ')' ; -adtConsParam: localVar | constructorName ; - -argList - : varList # argNoAttr - | (var ',')* attrSeq # argWithAttr - ; - -varList: (var (',' var)*)? ; -var: localVar (':' typeExpr)? ; - -attrSeq: attr (',' attr)* ; -attr: CNAME '=' expr ; - -typeExpr - : '(' ')' # tupleType - | '(' typeExpr ')' # typeParen - | '(' typeExpr ',' ')' # tupleType - | '(' typeExpr (',' typeExpr)+ ')' # tupleType - | generalIdent typeParamList # typeCallType - | generalIdent # typeIdentType - | 'Tensor' '[' shapeList ',' typeExpr ']' # tensorType - | 'fn' typeParamList? '(' (typeExpr (',' typeExpr)*)? ')' '->' typeExpr # funcType - | '_' # incompleteType - ; - -typeParamList: '[' typeExpr (',' typeExpr)* ']' ; - -shapeList - : '(' ')' - | '(' shape (',' shape)+ ')' - | shape - ; - -meta : 'meta' '[' CNAME ']' '[' NAT ']'; - -shape - : meta # metaShape - | '(' shape ')' # parensShape - | NAT # intShape - ; - -body: '{' expr '}' ; - -scalar - : FLOAT # scalarFloat - | NAT # scalarInt - | BOOL_LIT # scalarBool - ; - -ident - : generalIdent - | globalVar - | localVar - | graphVar - ; diff --git a/python/tvm/relay/grammar/__init__.py b/python/tvm/relay/grammar/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/python/tvm/relay/grammar/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/python/tvm/relay/grammar/py3/.gitattributes b/python/tvm/relay/grammar/py3/.gitattributes deleted file mode 100644 index 0eaf9078bc4f6..0000000000000 --- a/python/tvm/relay/grammar/py3/.gitattributes +++ /dev/null @@ -1,3 +0,0 @@ -Relay* binary -Relay* linguist-generated=true -Relay* linguist-detectable=false diff --git a/python/tvm/relay/grammar/py3/RelayLexer.py b/python/tvm/relay/grammar/py3/RelayLexer.py deleted file mode 100644 index 76e988b454180..0000000000000 --- a/python/tvm/relay/grammar/py3/RelayLexer.py +++ /dev/null @@ -1,256 +0,0 @@ -# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 -from antlr4 import * -from io import StringIO -from typing.io import TextIO -import sys - - - -def serializedATN(): - with StringIO() as buf: - buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2\62") - buf.write("\u0161\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7") - buf.write("\t\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r") - buf.write("\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23") - buf.write("\t\23\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30") - buf.write("\4\31\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36") - buf.write("\t\36\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\4$\t$\4%\t%") - buf.write("\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t,\4-\t-\4.") - buf.write("\t.\4/\t/\4\60\t\60\4\61\t\61\4\62\t\62\4\63\t\63\4\64") - buf.write("\t\64\4\65\t\65\4\66\t\66\3\2\3\2\3\3\3\3\3\4\3\4\3\5") - buf.write("\3\5\3\6\3\6\3\7\3\7\3\b\3\b\3\t\3\t\3\n\3\n\3\13\3\13") - buf.write("\3\13\3\f\3\f\3\f\3\f\3\f\3\r\3\r\3\16\3\16\3\17\3\17") - buf.write("\3\17\3\17\3\20\3\20\3\21\3\21\3\22\3\22\3\22\3\23\3\23") - buf.write("\3\23\3\24\3\24\3\24\3\25\3\25\3\25\3\25\3\26\3\26\3\26") - buf.write("\3\26\3\26\3\26\3\26\3\27\3\27\3\27\3\27\3\27\3\30\3\30") - buf.write("\3\30\3\31\3\31\3\31\3\31\3\31\3\31\3\32\3\32\3\32\3\32") - buf.write("\3\32\3\32\3\32\3\33\3\33\3\34\3\34\3\34\3\34\3\34\3\34") - buf.write("\3\34\3\35\3\35\3\35\3\35\3\35\3\36\3\36\3\36\3\36\3\36") - buf.write("\3\36\3\36\3\37\3\37\3\37\3\37\3\37\7\37\u00d7\n\37\f") - buf.write("\37\16\37\u00da\13\37\3\37\3\37\3\37\3\37\3\37\3 \6 \u00e2") - buf.write("\n \r \16 \u00e3\3 \3 \3!\3!\3!\3!\7!\u00ec\n!\f!\16!") - buf.write("\u00ef\13!\3!\3!\3!\3!\3\"\3\"\3\"\3#\3#\3#\7#\u00fb\n") - buf.write("#\f#\16#\u00fe\13#\3#\3#\3$\3$\3%\3%\3&\3&\3\'\3\'\3(") - buf.write("\3(\3)\3)\3*\3*\3*\3+\3+\3+\3,\3,\3,\3-\3-\3-\3.\3.\3") - buf.write(".\3.\3.\3.\3.\3.\3.\5.\u0123\n.\3/\3/\5/\u0127\n/\3/\3") - buf.write("/\3/\7/\u012c\n/\f/\16/\u012f\13/\3/\3/\7/\u0133\n/\f") - buf.write("/\16/\u0136\13/\3\60\3\60\3\60\5\60\u013b\n\60\3\60\5") - buf.write("\60\u013e\n\60\3\61\3\61\3\61\3\62\6\62\u0144\n\62\r\62") - buf.write("\16\62\u0145\3\63\3\63\5\63\u014a\n\63\3\63\3\63\3\64") - buf.write("\3\64\3\65\3\65\3\66\3\66\3\66\3\66\3\66\3\66\3\66\3\66") - buf.write("\3\66\3\66\3\66\7\66\u015d\n\66\f\66\16\66\u0160\13\66") - buf.write("\5\u00d8\u00ed\u00fc\2\67\3\3\5\4\7\5\t\6\13\7\r\b\17") - buf.write("\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20\37\21!\22#\23") - buf.write("%\24\'\25)\26+\27-\30/\31\61\32\63\33\65\34\67\359\36") - buf.write(";\37= ?!A\"C\2E#G$I%K&M\'O(Q)S*U+W,Y-[.]/_\2a\60c\61e") - buf.write("\2g\2i\2k\62\3\2\b\5\2\13\f\17\17\"\"\4\2\f\f\17\17\4") - buf.write("\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2\u016c\2\3\3\2\2\2\2") - buf.write("\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2\2\13\3\2\2\2\2\r\3") - buf.write("\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2\23\3\2\2\2\2\25\3\2") - buf.write("\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33\3\2\2\2\2\35\3\2\2") - buf.write("\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2\2\2%\3\2\2\2\2\'\3") - buf.write("\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2\2\2\2/\3\2\2\2\2\61") - buf.write("\3\2\2\2\2\63\3\2\2\2\2\65\3\2\2\2\2\67\3\2\2\2\29\3\2") - buf.write("\2\2\2;\3\2\2\2\2=\3\2\2\2\2?\3\2\2\2\2A\3\2\2\2\2E\3") - buf.write("\2\2\2\2G\3\2\2\2\2I\3\2\2\2\2K\3\2\2\2\2M\3\2\2\2\2O") - buf.write("\3\2\2\2\2Q\3\2\2\2\2S\3\2\2\2\2U\3\2\2\2\2W\3\2\2\2\2") - buf.write("Y\3\2\2\2\2[\3\2\2\2\2]\3\2\2\2\2a\3\2\2\2\2c\3\2\2\2") - buf.write("\2k\3\2\2\2\3m\3\2\2\2\5o\3\2\2\2\7q\3\2\2\2\ts\3\2\2") - buf.write("\2\13u\3\2\2\2\rw\3\2\2\2\17y\3\2\2\2\21{\3\2\2\2\23}") - buf.write("\3\2\2\2\25\177\3\2\2\2\27\u0082\3\2\2\2\31\u0087\3\2") - buf.write("\2\2\33\u0089\3\2\2\2\35\u008b\3\2\2\2\37\u008f\3\2\2") - buf.write("\2!\u0091\3\2\2\2#\u0093\3\2\2\2%\u0096\3\2\2\2\'\u0099") - buf.write("\3\2\2\2)\u009c\3\2\2\2+\u00a0\3\2\2\2-\u00a7\3\2\2\2") - buf.write("/\u00ac\3\2\2\2\61\u00af\3\2\2\2\63\u00b5\3\2\2\2\65\u00bc") - buf.write("\3\2\2\2\67\u00be\3\2\2\29\u00c5\3\2\2\2;\u00ca\3\2\2") - buf.write("\2=\u00d1\3\2\2\2?\u00e1\3\2\2\2A\u00e7\3\2\2\2C\u00f4") - buf.write("\3\2\2\2E\u00f7\3\2\2\2G\u0101\3\2\2\2I\u0103\3\2\2\2") - buf.write("K\u0105\3\2\2\2M\u0107\3\2\2\2O\u0109\3\2\2\2Q\u010b\3") - buf.write("\2\2\2S\u010d\3\2\2\2U\u0110\3\2\2\2W\u0113\3\2\2\2Y\u0116") - buf.write("\3\2\2\2[\u0122\3\2\2\2]\u0126\3\2\2\2_\u0137\3\2\2\2") - buf.write("a\u013f\3\2\2\2c\u0143\3\2\2\2e\u0147\3\2\2\2g\u014d\3") - buf.write("\2\2\2i\u014f\3\2\2\2k\u0151\3\2\2\2mn\7\60\2\2n\4\3\2") - buf.write("\2\2op\7B\2\2p\6\3\2\2\2qr\7\'\2\2r\b\3\2\2\2st\7a\2\2") - buf.write("t\n\3\2\2\2uv\7.\2\2v\f\3\2\2\2wx\7*\2\2x\16\3\2\2\2y") - buf.write("z\7+\2\2z\20\3\2\2\2{|\7]\2\2|\22\3\2\2\2}~\7_\2\2~\24") - buf.write("\3\2\2\2\177\u0080\7k\2\2\u0080\u0081\7h\2\2\u0081\26") - buf.write("\3\2\2\2\u0082\u0083\7g\2\2\u0083\u0084\7n\2\2\u0084\u0085") - buf.write("\7u\2\2\u0085\u0086\7g\2\2\u0086\30\3\2\2\2\u0087\u0088") - buf.write("\7}\2\2\u0088\32\3\2\2\2\u0089\u008a\7\177\2\2\u008a\34") - buf.write("\3\2\2\2\u008b\u008c\7n\2\2\u008c\u008d\7g\2\2\u008d\u008e") - buf.write("\7v\2\2\u008e\36\3\2\2\2\u008f\u0090\7?\2\2\u0090 \3\2") - buf.write("\2\2\u0091\u0092\7=\2\2\u0092\"\3\2\2\2\u0093\u0094\7") - buf.write("=\2\2\u0094\u0095\7=\2\2\u0095$\3\2\2\2\u0096\u0097\7") - buf.write("h\2\2\u0097\u0098\7p\2\2\u0098&\3\2\2\2\u0099\u009a\7") - buf.write("/\2\2\u009a\u009b\7@\2\2\u009b(\3\2\2\2\u009c\u009d\7") - buf.write("f\2\2\u009d\u009e\7g\2\2\u009e\u009f\7h\2\2\u009f*\3\2") - buf.write("\2\2\u00a0\u00a1\7g\2\2\u00a1\u00a2\7z\2\2\u00a2\u00a3") - buf.write("\7v\2\2\u00a3\u00a4\7g\2\2\u00a4\u00a5\7t\2\2\u00a5\u00a6") - buf.write("\7p\2\2\u00a6,\3\2\2\2\u00a7\u00a8\7v\2\2\u00a8\u00a9") - buf.write("\7{\2\2\u00a9\u00aa\7r\2\2\u00aa\u00ab\7g\2\2\u00ab.\3") - buf.write("\2\2\2\u00ac\u00ad\7?\2\2\u00ad\u00ae\7@\2\2\u00ae\60") - buf.write("\3\2\2\2\u00af\u00b0\7o\2\2\u00b0\u00b1\7c\2\2\u00b1\u00b2") - buf.write("\7v\2\2\u00b2\u00b3\7e\2\2\u00b3\u00b4\7j\2\2\u00b4\62") - buf.write("\3\2\2\2\u00b5\u00b6\7o\2\2\u00b6\u00b7\7c\2\2\u00b7\u00b8") - buf.write("\7v\2\2\u00b8\u00b9\7e\2\2\u00b9\u00ba\7j\2\2\u00ba\u00bb") - buf.write("\7A\2\2\u00bb\64\3\2\2\2\u00bc\u00bd\7<\2\2\u00bd\66\3") - buf.write("\2\2\2\u00be\u00bf\7V\2\2\u00bf\u00c0\7g\2\2\u00c0\u00c1") - buf.write("\7p\2\2\u00c1\u00c2\7u\2\2\u00c2\u00c3\7q\2\2\u00c3\u00c4") - buf.write("\7t\2\2\u00c48\3\2\2\2\u00c5\u00c6\7o\2\2\u00c6\u00c7") - buf.write("\7g\2\2\u00c7\u00c8\7v\2\2\u00c8\u00c9\7c\2\2\u00c9:\3") - buf.write("\2\2\2\u00ca\u00cb\7x\2\2\u00cb\u00cc\7\62\2\2\u00cc\u00cd") - buf.write("\7\60\2\2\u00cd\u00ce\7\62\2\2\u00ce\u00cf\7\60\2\2\u00cf") - buf.write("\u00d0\7\66\2\2\u00d0<\3\2\2\2\u00d1\u00d2\7\61\2\2\u00d2") - buf.write("\u00d3\7,\2\2\u00d3\u00d8\3\2\2\2\u00d4\u00d7\5=\37\2") - buf.write("\u00d5\u00d7\13\2\2\2\u00d6\u00d4\3\2\2\2\u00d6\u00d5") - buf.write("\3\2\2\2\u00d7\u00da\3\2\2\2\u00d8\u00d9\3\2\2\2\u00d8") - buf.write("\u00d6\3\2\2\2\u00d9\u00db\3\2\2\2\u00da\u00d8\3\2\2\2") - buf.write("\u00db\u00dc\7,\2\2\u00dc\u00dd\7\61\2\2\u00dd\u00de\3") - buf.write("\2\2\2\u00de\u00df\b\37\2\2\u00df>\3\2\2\2\u00e0\u00e2") - buf.write("\t\2\2\2\u00e1\u00e0\3\2\2\2\u00e2\u00e3\3\2\2\2\u00e3") - buf.write("\u00e1\3\2\2\2\u00e3\u00e4\3\2\2\2\u00e4\u00e5\3\2\2\2") - buf.write("\u00e5\u00e6\b \2\2\u00e6@\3\2\2\2\u00e7\u00e8\7\61\2") - buf.write("\2\u00e8\u00e9\7\61\2\2\u00e9\u00ed\3\2\2\2\u00ea\u00ec") - buf.write("\13\2\2\2\u00eb\u00ea\3\2\2\2\u00ec\u00ef\3\2\2\2\u00ed") - buf.write("\u00ee\3\2\2\2\u00ed\u00eb\3\2\2\2\u00ee\u00f0\3\2\2\2") - buf.write("\u00ef\u00ed\3\2\2\2\u00f0\u00f1\7\f\2\2\u00f1\u00f2\3") - buf.write("\2\2\2\u00f2\u00f3\b!\2\2\u00f3B\3\2\2\2\u00f4\u00f5\7") - buf.write("^\2\2\u00f5\u00f6\7$\2\2\u00f6D\3\2\2\2\u00f7\u00fc\7") - buf.write("$\2\2\u00f8\u00fb\5C\"\2\u00f9\u00fb\n\3\2\2\u00fa\u00f8") - buf.write("\3\2\2\2\u00fa\u00f9\3\2\2\2\u00fb\u00fe\3\2\2\2\u00fc") - buf.write("\u00fd\3\2\2\2\u00fc\u00fa\3\2\2\2\u00fd\u00ff\3\2\2\2") - buf.write("\u00fe\u00fc\3\2\2\2\u00ff\u0100\7$\2\2\u0100F\3\2\2\2") - buf.write("\u0101\u0102\7,\2\2\u0102H\3\2\2\2\u0103\u0104\7\61\2") - buf.write("\2\u0104J\3\2\2\2\u0105\u0106\7-\2\2\u0106L\3\2\2\2\u0107") - buf.write("\u0108\7/\2\2\u0108N\3\2\2\2\u0109\u010a\7>\2\2\u010a") - buf.write("P\3\2\2\2\u010b\u010c\7@\2\2\u010cR\3\2\2\2\u010d\u010e") - buf.write("\7>\2\2\u010e\u010f\7?\2\2\u010fT\3\2\2\2\u0110\u0111") - buf.write("\7@\2\2\u0111\u0112\7?\2\2\u0112V\3\2\2\2\u0113\u0114") - buf.write("\7?\2\2\u0114\u0115\7?\2\2\u0115X\3\2\2\2\u0116\u0117") - buf.write("\7#\2\2\u0117\u0118\7?\2\2\u0118Z\3\2\2\2\u0119\u011a") - buf.write("\7V\2\2\u011a\u011b\7t\2\2\u011b\u011c\7w\2\2\u011c\u0123") - buf.write("\7g\2\2\u011d\u011e\7H\2\2\u011e\u011f\7c\2\2\u011f\u0120") - buf.write("\7n\2\2\u0120\u0121\7u\2\2\u0121\u0123\7g\2\2\u0122\u0119") - buf.write("\3\2\2\2\u0122\u011d\3\2\2\2\u0123\\\3\2\2\2\u0124\u0127") - buf.write("\7a\2\2\u0125\u0127\5g\64\2\u0126\u0124\3\2\2\2\u0126") - buf.write("\u0125\3\2\2\2\u0127\u012d\3\2\2\2\u0128\u012c\7a\2\2") - buf.write("\u0129\u012c\5g\64\2\u012a\u012c\5i\65\2\u012b\u0128\3") - buf.write("\2\2\2\u012b\u0129\3\2\2\2\u012b\u012a\3\2\2\2\u012c\u012f") - buf.write("\3\2\2\2\u012d\u012b\3\2\2\2\u012d\u012e\3\2\2\2\u012e") - buf.write("\u0134\3\2\2\2\u012f\u012d\3\2\2\2\u0130\u0131\7\60\2") - buf.write("\2\u0131\u0133\5]/\2\u0132\u0130\3\2\2\2\u0133\u0136\3") - buf.write("\2\2\2\u0134\u0132\3\2\2\2\u0134\u0135\3\2\2\2\u0135^") - buf.write("\3\2\2\2\u0136\u0134\3\2\2\2\u0137\u013a\5c\62\2\u0138") - buf.write("\u0139\7\60\2\2\u0139\u013b\5c\62\2\u013a\u0138\3\2\2") - buf.write("\2\u013a\u013b\3\2\2\2\u013b\u013d\3\2\2\2\u013c\u013e") - buf.write("\5e\63\2\u013d\u013c\3\2\2\2\u013d\u013e\3\2\2\2\u013e") - buf.write("`\3\2\2\2\u013f\u0140\5_\60\2\u0140\u0141\7h\2\2\u0141") - buf.write("b\3\2\2\2\u0142\u0144\5i\65\2\u0143\u0142\3\2\2\2\u0144") - buf.write("\u0145\3\2\2\2\u0145\u0143\3\2\2\2\u0145\u0146\3\2\2\2") - buf.write("\u0146d\3\2\2\2\u0147\u0149\t\4\2\2\u0148\u014a\t\5\2") - buf.write("\2\u0149\u0148\3\2\2\2\u0149\u014a\3\2\2\2\u014a\u014b") - buf.write("\3\2\2\2\u014b\u014c\5c\62\2\u014cf\3\2\2\2\u014d\u014e") - buf.write("\t\6\2\2\u014eh\3\2\2\2\u014f\u0150\t\7\2\2\u0150j\3\2") - buf.write("\2\2\u0151\u0152\7O\2\2\u0152\u0153\7G\2\2\u0153\u0154") - buf.write("\7V\2\2\u0154\u0155\7C\2\2\u0155\u0156\7F\2\2\u0156\u0157") - buf.write("\7C\2\2\u0157\u0158\7V\2\2\u0158\u0159\7C\2\2\u0159\u015a") - buf.write("\7<\2\2\u015a\u015e\3\2\2\2\u015b\u015d\13\2\2\2\u015c") - buf.write("\u015b\3\2\2\2\u015d\u0160\3\2\2\2\u015e\u015c\3\2\2\2") - buf.write("\u015e\u015f\3\2\2\2\u015fl\3\2\2\2\u0160\u015e\3\2\2") - buf.write("\2\23\2\u00d6\u00d8\u00e3\u00ed\u00fa\u00fc\u0122\u0126") - buf.write("\u012b\u012d\u0134\u013a\u013d\u0145\u0149\u015e\3\b\2") - buf.write("\2") - return buf.getvalue() - - -class RelayLexer(Lexer): - - atn = ATNDeserializer().deserialize(serializedATN()) - - decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] - - T__0 = 1 - T__1 = 2 - T__2 = 3 - T__3 = 4 - T__4 = 5 - T__5 = 6 - T__6 = 7 - T__7 = 8 - T__8 = 9 - T__9 = 10 - T__10 = 11 - T__11 = 12 - T__12 = 13 - T__13 = 14 - T__14 = 15 - T__15 = 16 - T__16 = 17 - T__17 = 18 - T__18 = 19 - T__19 = 20 - T__20 = 21 - T__21 = 22 - T__22 = 23 - T__23 = 24 - T__24 = 25 - T__25 = 26 - T__26 = 27 - T__27 = 28 - SEMVER = 29 - COMMENT = 30 - WS = 31 - LINE_COMMENT = 32 - QUOTED_STRING = 33 - MUL = 34 - DIV = 35 - ADD = 36 - SUB = 37 - LT = 38 - GT = 39 - LE = 40 - GE = 41 - EQ = 42 - NE = 43 - BOOL_LIT = 44 - CNAME = 45 - FLOAT = 46 - NAT = 47 - METADATA = 48 - - channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] - - modeNames = [ "DEFAULT_MODE" ] - - literalNames = [ "", - "'.'", "'@'", "'%'", "'_'", "','", "'('", "')'", "'['", "']'", - "'if'", "'else'", "'{'", "'}'", "'let'", "'='", "';'", "';;'", - "'fn'", "'->'", "'def'", "'extern'", "'type'", "'=>'", "'match'", - "'match?'", "':'", "'Tensor'", "'meta'", "'v0.0.4'", "'*'", - "'/'", "'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", "'!='" ] - - symbolicNames = [ "", - "SEMVER", "COMMENT", "WS", "LINE_COMMENT", "QUOTED_STRING", - "MUL", "DIV", "ADD", "SUB", "LT", "GT", "LE", "GE", "EQ", "NE", - "BOOL_LIT", "CNAME", "FLOAT", "NAT", "METADATA" ] - - ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", - "T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13", - "T__14", "T__15", "T__16", "T__17", "T__18", "T__19", - "T__20", "T__21", "T__22", "T__23", "T__24", "T__25", - "T__26", "T__27", "SEMVER", "COMMENT", "WS", "LINE_COMMENT", - "ESCAPED_QUOTE", "QUOTED_STRING", "MUL", "DIV", "ADD", - "SUB", "LT", "GT", "LE", "GE", "EQ", "NE", "BOOL_LIT", - "CNAME", "PREFLOAT", "FLOAT", "NAT", "EXP", "LETTER", - "DIGIT", "METADATA" ] - - grammarFileName = "Relay.g4" - - def __init__(self, input=None, output:TextIO = sys.stdout): - super().__init__(input, output) - self.checkVersion("4.7.2") - self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) - self._actions = None - self._predicates = None - - diff --git a/python/tvm/relay/grammar/py3/RelayParser.py b/python/tvm/relay/grammar/py3/RelayParser.py deleted file mode 100644 index f24eed4be92f7..0000000000000 --- a/python/tvm/relay/grammar/py3/RelayParser.py +++ /dev/null @@ -1,3732 +0,0 @@ -# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 -# encoding: utf-8 -from antlr4 import * -from io import StringIO -from typing.io import TextIO -import sys - - -def serializedATN(): - with StringIO() as buf: - buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\3\62") - buf.write("\u0200\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7\t\7") - buf.write("\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r\4\16") - buf.write("\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23\t\23") - buf.write("\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30\4\31") - buf.write("\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36\t\36") - buf.write("\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\3\2\3\2\7\2I\n\2") - buf.write("\f\2\16\2L\13\2\3\2\5\2O\n\2\3\2\5\2R\n\2\3\2\3\2\3\3") - buf.write("\3\3\3\3\7\3Y\n\3\f\3\16\3\\\13\3\3\4\3\4\3\4\3\5\3\5") - buf.write("\3\5\3\6\3\6\3\6\3\7\3\7\3\7\7\7j\n\7\f\7\16\7m\13\7\5") - buf.write("\7o\n\7\3\b\3\b\3\b\3\b\7\bu\n\b\f\b\16\bx\13\b\3\b\5") - buf.write("\b{\n\b\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3") - buf.write("\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\6\t\u0090\n\t\r\t\16\t") - buf.write("\u0091\3\t\3\t\3\t\3\t\3\t\3\t\7\t\u009a\n\t\f\t\16\t") - buf.write("\u009d\13\t\5\t\u009f\n\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t") - buf.write("\3\t\3\t\3\t\3\t\3\t\3\t\5\t\u00ae\n\t\3\t\3\t\3\t\3\t") - buf.write("\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3") - buf.write("\t\3\t\5\t\u00c3\n\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3") - buf.write("\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t") - buf.write("\3\t\7\t\u00dc\n\t\f\t\16\t\u00df\13\t\3\n\3\n\5\n\u00e3") - buf.write("\n\n\3\n\3\n\3\n\3\n\3\n\5\n\u00ea\n\n\3\n\3\n\3\13\3") - buf.write("\13\3\13\5\13\u00f1\n\13\3\13\3\13\3\13\3\13\3\13\5\13") - buf.write("\u00f8\n\13\3\13\3\13\3\13\3\13\3\13\3\13\5\13\u0100\n") - buf.write("\13\3\13\3\13\3\13\5\13\u0105\n\13\3\13\3\13\5\13\u0109") - buf.write("\n\13\3\13\3\13\5\13\u010d\n\13\3\f\3\f\3\r\3\r\3\r\7") - buf.write("\r\u0114\n\r\f\r\16\r\u0117\13\r\3\r\5\r\u011a\n\r\3\16") - buf.write("\3\16\3\16\3\16\3\16\7\16\u0121\n\16\f\16\16\16\u0124") - buf.write("\13\16\3\16\3\16\5\16\u0128\n\16\3\17\3\17\3\17\7\17\u012d") - buf.write("\n\17\f\17\16\17\u0130\13\17\3\17\5\17\u0133\n\17\3\20") - buf.write("\3\20\3\20\3\20\3\20\3\20\3\20\5\20\u013c\n\20\3\21\3") - buf.write("\21\3\22\3\22\3\22\3\22\7\22\u0144\n\22\f\22\16\22\u0147") - buf.write("\13\22\3\22\3\22\3\23\3\23\3\23\3\23\5\23\u014f\n\23\3") - buf.write("\23\3\23\5\23\u0153\n\23\3\23\5\23\u0156\n\23\3\24\3\24") - buf.write("\5\24\u015a\n\24\3\25\3\25\3\25\3\25\7\25\u0160\n\25\f") - buf.write("\25\16\25\u0163\13\25\3\25\3\25\3\26\3\26\5\26\u0169\n") - buf.write("\26\3\27\3\27\3\27\3\27\7\27\u016f\n\27\f\27\16\27\u0172") - buf.write("\13\27\3\27\5\27\u0175\n\27\3\30\3\30\3\30\7\30\u017a") - buf.write("\n\30\f\30\16\30\u017d\13\30\5\30\u017f\n\30\3\31\3\31") - buf.write("\3\31\5\31\u0184\n\31\3\32\3\32\3\32\7\32\u0189\n\32\f") - buf.write("\32\16\32\u018c\13\32\3\33\3\33\3\33\3\33\3\34\3\34\3") - buf.write("\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34") - buf.write("\3\34\3\34\6\34\u01a1\n\34\r\34\16\34\u01a2\3\34\3\34") - buf.write("\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34") - buf.write("\3\34\3\34\5\34\u01b4\n\34\3\34\3\34\3\34\3\34\7\34\u01ba") - buf.write("\n\34\f\34\16\34\u01bd\13\34\5\34\u01bf\n\34\3\34\3\34") - buf.write("\3\34\3\34\5\34\u01c5\n\34\3\35\3\35\3\35\3\35\7\35\u01cb") - buf.write("\n\35\f\35\16\35\u01ce\13\35\3\35\3\35\3\36\3\36\3\36") - buf.write("\3\36\3\36\3\36\6\36\u01d8\n\36\r\36\16\36\u01d9\3\36") - buf.write("\3\36\3\36\5\36\u01df\n\36\3\37\3\37\3\37\3\37\3\37\3") - buf.write("\37\3\37\3\37\3 \3 \3 \3 \3 \3 \5 \u01ef\n \3!\3!\3!\3") - buf.write("!\3\"\3\"\3\"\5\"\u01f8\n\"\3#\3#\3#\3#\5#\u01fe\n#\3") - buf.write("#\2\3\20$\2\4\6\b\n\f\16\20\22\24\26\30\32\34\36 \"$&") - buf.write("(*,.\60\62\64\668:<>@BD\2\b\4\2\6\6//\3\2$%\3\2&\'\3\2") - buf.write("(+\3\2,-\3\2\32\33\2\u0234\2F\3\2\2\2\4U\3\2\2\2\6]\3") - buf.write("\2\2\2\b`\3\2\2\2\nc\3\2\2\2\fn\3\2\2\2\16z\3\2\2\2\20") - buf.write("\u00c2\3\2\2\2\22\u00e0\3\2\2\2\24\u010c\3\2\2\2\26\u010e") - buf.write("\3\2\2\2\30\u0110\3\2\2\2\32\u011b\3\2\2\2\34\u0129\3") - buf.write("\2\2\2\36\u0134\3\2\2\2 \u013d\3\2\2\2\"\u013f\3\2\2\2") - buf.write("$\u0155\3\2\2\2&\u0157\3\2\2\2(\u015b\3\2\2\2*\u0168\3") - buf.write("\2\2\2,\u0174\3\2\2\2.\u017e\3\2\2\2\60\u0180\3\2\2\2") - buf.write("\62\u0185\3\2\2\2\64\u018d\3\2\2\2\66\u01c4\3\2\2\28\u01c6") - buf.write("\3\2\2\2:\u01de\3\2\2\2<\u01e0\3\2\2\2>\u01ee\3\2\2\2") - buf.write("@\u01f0\3\2\2\2B\u01f7\3\2\2\2D\u01fd\3\2\2\2FN\7\37\2") - buf.write("\2GI\5\24\13\2HG\3\2\2\2IL\3\2\2\2JH\3\2\2\2JK\3\2\2\2") - buf.write("KO\3\2\2\2LJ\3\2\2\2MO\5\20\t\2NJ\3\2\2\2NM\3\2\2\2OQ") - buf.write("\3\2\2\2PR\7\62\2\2QP\3\2\2\2QR\3\2\2\2RS\3\2\2\2ST\7") - buf.write("\2\2\3T\3\3\2\2\2UZ\7/\2\2VW\7\3\2\2WY\7/\2\2XV\3\2\2") - buf.write("\2Y\\\3\2\2\2ZX\3\2\2\2Z[\3\2\2\2[\5\3\2\2\2\\Z\3\2\2") - buf.write("\2]^\7\4\2\2^_\7/\2\2_\7\3\2\2\2`a\7\5\2\2ab\t\2\2\2b") - buf.write("\t\3\2\2\2cd\7\5\2\2de\7\61\2\2e\13\3\2\2\2fk\5\20\t\2") - buf.write("gh\7\7\2\2hj\5\20\t\2ig\3\2\2\2jm\3\2\2\2ki\3\2\2\2kl") - buf.write("\3\2\2\2lo\3\2\2\2mk\3\2\2\2nf\3\2\2\2no\3\2\2\2o\r\3") - buf.write("\2\2\2p{\5\f\7\2qr\5\20\t\2rs\7\7\2\2su\3\2\2\2tq\3\2") - buf.write("\2\2ux\3\2\2\2vt\3\2\2\2vw\3\2\2\2wy\3\2\2\2xv\3\2\2\2") - buf.write("y{\5\62\32\2zp\3\2\2\2zv\3\2\2\2{\17\3\2\2\2|}\b\t\1\2") - buf.write("}~\7\b\2\2~\177\5\20\t\2\177\u0080\7\t\2\2\u0080\u00c3") - buf.write("\3\2\2\2\u0081\u0082\7\'\2\2\u0082\u00c3\5\20\t\26\u0083") - buf.write("\u00c3\5\22\n\2\u0084\u0085\7\b\2\2\u0085\u00c3\7\t\2") - buf.write("\2\u0086\u0087\7\b\2\2\u0087\u0088\5\20\t\2\u0088\u0089") - buf.write("\7\7\2\2\u0089\u008a\7\t\2\2\u008a\u00c3\3\2\2\2\u008b") - buf.write("\u008c\7\b\2\2\u008c\u008f\5\20\t\2\u008d\u008e\7\7\2") - buf.write("\2\u008e\u0090\5\20\t\2\u008f\u008d\3\2\2\2\u0090\u0091") - buf.write("\3\2\2\2\u0091\u008f\3\2\2\2\u0091\u0092\3\2\2\2\u0092") - buf.write("\u0093\3\2\2\2\u0093\u0094\7\t\2\2\u0094\u00c3\3\2\2\2") - buf.write("\u0095\u009e\7\n\2\2\u0096\u009b\5\20\t\2\u0097\u0098") - buf.write("\7\7\2\2\u0098\u009a\5\20\t\2\u0099\u0097\3\2\2\2\u009a") - buf.write("\u009d\3\2\2\2\u009b\u0099\3\2\2\2\u009b\u009c\3\2\2\2") - buf.write("\u009c\u009f\3\2\2\2\u009d\u009b\3\2\2\2\u009e\u0096\3") - buf.write("\2\2\2\u009e\u009f\3\2\2\2\u009f\u00a0\3\2\2\2\u00a0\u00c3") - buf.write("\7\13\2\2\u00a1\u00a2\7\f\2\2\u00a2\u00a3\7\b\2\2\u00a3") - buf.write("\u00a4\5\20\t\2\u00a4\u00a5\7\t\2\2\u00a5\u00a6\5@!\2") - buf.write("\u00a6\u00a7\7\r\2\2\u00a7\u00a8\5@!\2\u00a8\u00c3\3\2") - buf.write("\2\2\u00a9\u00aa\5 \21\2\u00aa\u00ab\5\20\t\2\u00ab\u00ad") - buf.write("\7\16\2\2\u00ac\u00ae\5\34\17\2\u00ad\u00ac\3\2\2\2\u00ad") - buf.write("\u00ae\3\2\2\2\u00ae\u00af\3\2\2\2\u00af\u00b0\7\17\2") - buf.write("\2\u00b0\u00c3\3\2\2\2\u00b1\u00b2\7\20\2\2\u00b2\u00b3") - buf.write("\5\60\31\2\u00b3\u00b4\7\21\2\2\u00b4\u00b5\5\20\t\2\u00b5") - buf.write("\u00b6\7\22\2\2\u00b6\u00b7\5\20\t\t\u00b7\u00c3\3\2\2") - buf.write("\2\u00b8\u00b9\5\n\6\2\u00b9\u00ba\7\21\2\2\u00ba\u00bb") - buf.write("\5\20\t\2\u00bb\u00bc\7\22\2\2\u00bc\u00bd\5\20\t\7\u00bd") - buf.write("\u00c3\3\2\2\2\u00be\u00c3\5D#\2\u00bf\u00c3\5B\"\2\u00c0") - buf.write("\u00c3\5<\37\2\u00c1\u00c3\7#\2\2\u00c2|\3\2\2\2\u00c2") - buf.write("\u0081\3\2\2\2\u00c2\u0083\3\2\2\2\u00c2\u0084\3\2\2\2") - buf.write("\u00c2\u0086\3\2\2\2\u00c2\u008b\3\2\2\2\u00c2\u0095\3") - buf.write("\2\2\2\u00c2\u00a1\3\2\2\2\u00c2\u00a9\3\2\2\2\u00c2\u00b1") - buf.write("\3\2\2\2\u00c2\u00b8\3\2\2\2\u00c2\u00be\3\2\2\2\u00c2") - buf.write("\u00bf\3\2\2\2\u00c2\u00c0\3\2\2\2\u00c2\u00c1\3\2\2\2") - buf.write("\u00c3\u00dd\3\2\2\2\u00c4\u00c5\f\25\2\2\u00c5\u00c6") - buf.write("\t\3\2\2\u00c6\u00dc\5\20\t\26\u00c7\u00c8\f\24\2\2\u00c8") - buf.write("\u00c9\t\4\2\2\u00c9\u00dc\5\20\t\25\u00ca\u00cb\f\23") - buf.write("\2\2\u00cb\u00cc\t\5\2\2\u00cc\u00dc\5\20\t\24\u00cd\u00ce") - buf.write("\f\22\2\2\u00ce\u00cf\t\6\2\2\u00cf\u00dc\5\20\t\23\u00d0") - buf.write("\u00d1\f\b\2\2\u00d1\u00d2\7\23\2\2\u00d2\u00dc\5\20\t") - buf.write("\t\u00d3\u00d4\f\27\2\2\u00d4\u00d5\7\b\2\2\u00d5\u00d6") - buf.write("\5\16\b\2\u00d6\u00d7\7\t\2\2\u00d7\u00dc\3\2\2\2\u00d8") - buf.write("\u00d9\f\n\2\2\u00d9\u00da\7\3\2\2\u00da\u00dc\7\61\2") - buf.write("\2\u00db\u00c4\3\2\2\2\u00db\u00c7\3\2\2\2\u00db\u00ca") - buf.write("\3\2\2\2\u00db\u00cd\3\2\2\2\u00db\u00d0\3\2\2\2\u00db") - buf.write("\u00d3\3\2\2\2\u00db\u00d8\3\2\2\2\u00dc\u00df\3\2\2\2") - buf.write("\u00dd\u00db\3\2\2\2\u00dd\u00de\3\2\2\2\u00de\21\3\2") - buf.write("\2\2\u00df\u00dd\3\2\2\2\u00e0\u00e2\7\24\2\2\u00e1\u00e3") - buf.write("\58\35\2\u00e2\u00e1\3\2\2\2\u00e2\u00e3\3\2\2\2\u00e3") - buf.write("\u00e4\3\2\2\2\u00e4\u00e5\7\b\2\2\u00e5\u00e6\5,\27\2") - buf.write("\u00e6\u00e9\7\t\2\2\u00e7\u00e8\7\25\2\2\u00e8\u00ea") - buf.write("\5\66\34\2\u00e9\u00e7\3\2\2\2\u00e9\u00ea\3\2\2\2\u00ea") - buf.write("\u00eb\3\2\2\2\u00eb\u00ec\5@!\2\u00ec\23\3\2\2\2\u00ed") - buf.write("\u00ee\7\26\2\2\u00ee\u00f0\5\6\4\2\u00ef\u00f1\58\35") - buf.write("\2\u00f0\u00ef\3\2\2\2\u00f0\u00f1\3\2\2\2\u00f1\u00f2") - buf.write("\3\2\2\2\u00f2\u00f3\7\b\2\2\u00f3\u00f4\5,\27\2\u00f4") - buf.write("\u00f7\7\t\2\2\u00f5\u00f6\7\25\2\2\u00f6\u00f8\5\66\34") - buf.write("\2\u00f7\u00f5\3\2\2\2\u00f7\u00f8\3\2\2\2\u00f8\u00f9") - buf.write("\3\2\2\2\u00f9\u00fa\5@!\2\u00fa\u010d\3\2\2\2\u00fb\u00fc") - buf.write("\7\27\2\2\u00fc\u00fd\7\30\2\2\u00fd\u00ff\5\4\3\2\u00fe") - buf.write("\u0100\58\35\2\u00ff\u00fe\3\2\2\2\u00ff\u0100\3\2\2\2") - buf.write("\u0100\u010d\3\2\2\2\u0101\u0102\7\30\2\2\u0102\u0104") - buf.write("\5\4\3\2\u0103\u0105\58\35\2\u0104\u0103\3\2\2\2\u0104") - buf.write("\u0105\3\2\2\2\u0105\u0106\3\2\2\2\u0106\u0108\7\16\2") - buf.write("\2\u0107\u0109\5\30\r\2\u0108\u0107\3\2\2\2\u0108\u0109") - buf.write("\3\2\2\2\u0109\u010a\3\2\2\2\u010a\u010b\7\17\2\2\u010b") - buf.write("\u010d\3\2\2\2\u010c\u00ed\3\2\2\2\u010c\u00fb\3\2\2\2") - buf.write("\u010c\u0101\3\2\2\2\u010d\25\3\2\2\2\u010e\u010f\7/\2") - buf.write("\2\u010f\27\3\2\2\2\u0110\u0115\5\32\16\2\u0111\u0112") - buf.write("\7\7\2\2\u0112\u0114\5\32\16\2\u0113\u0111\3\2\2\2\u0114") - buf.write("\u0117\3\2\2\2\u0115\u0113\3\2\2\2\u0115\u0116\3\2\2\2") - buf.write("\u0116\u0119\3\2\2\2\u0117\u0115\3\2\2\2\u0118\u011a\7") - buf.write("\7\2\2\u0119\u0118\3\2\2\2\u0119\u011a\3\2\2\2\u011a\31") - buf.write("\3\2\2\2\u011b\u0127\5\26\f\2\u011c\u011d\7\b\2\2\u011d") - buf.write("\u0122\5\66\34\2\u011e\u011f\7\7\2\2\u011f\u0121\5\66") - buf.write("\34\2\u0120\u011e\3\2\2\2\u0121\u0124\3\2\2\2\u0122\u0120") - buf.write("\3\2\2\2\u0122\u0123\3\2\2\2\u0123\u0125\3\2\2\2\u0124") - buf.write("\u0122\3\2\2\2\u0125\u0126\7\t\2\2\u0126\u0128\3\2\2\2") - buf.write("\u0127\u011c\3\2\2\2\u0127\u0128\3\2\2\2\u0128\33\3\2") - buf.write("\2\2\u0129\u012e\5\36\20\2\u012a\u012b\7\7\2\2\u012b\u012d") - buf.write("\5\36\20\2\u012c\u012a\3\2\2\2\u012d\u0130\3\2\2\2\u012e") - buf.write("\u012c\3\2\2\2\u012e\u012f\3\2\2\2\u012f\u0132\3\2\2\2") - buf.write("\u0130\u012e\3\2\2\2\u0131\u0133\7\7\2\2\u0132\u0131\3") - buf.write("\2\2\2\u0132\u0133\3\2\2\2\u0133\35\3\2\2\2\u0134\u0135") - buf.write("\5$\23\2\u0135\u013b\7\31\2\2\u0136\u0137\7\16\2\2\u0137") - buf.write("\u0138\5\20\t\2\u0138\u0139\7\17\2\2\u0139\u013c\3\2\2") - buf.write("\2\u013a\u013c\5\20\t\2\u013b\u0136\3\2\2\2\u013b\u013a") - buf.write("\3\2\2\2\u013c\37\3\2\2\2\u013d\u013e\t\7\2\2\u013e!\3") - buf.write("\2\2\2\u013f\u0140\7\b\2\2\u0140\u0145\5$\23\2\u0141\u0142") - buf.write("\7\7\2\2\u0142\u0144\5$\23\2\u0143\u0141\3\2\2\2\u0144") - buf.write("\u0147\3\2\2\2\u0145\u0143\3\2\2\2\u0145\u0146\3\2\2\2") - buf.write("\u0146\u0148\3\2\2\2\u0147\u0145\3\2\2\2\u0148\u0149\7") - buf.write("\t\2\2\u0149#\3\2\2\2\u014a\u0156\7\6\2\2\u014b\u014e") - buf.write("\5\b\5\2\u014c\u014d\7\34\2\2\u014d\u014f\5\66\34\2\u014e") - buf.write("\u014c\3\2\2\2\u014e\u014f\3\2\2\2\u014f\u0156\3\2\2\2") - buf.write("\u0150\u0152\5\26\f\2\u0151\u0153\5\"\22\2\u0152\u0151") - buf.write("\3\2\2\2\u0152\u0153\3\2\2\2\u0153\u0156\3\2\2\2\u0154") - buf.write("\u0156\5\"\22\2\u0155\u014a\3\2\2\2\u0155\u014b\3\2\2") - buf.write("\2\u0155\u0150\3\2\2\2\u0155\u0154\3\2\2\2\u0156%\3\2") - buf.write("\2\2\u0157\u0159\5\26\f\2\u0158\u015a\5(\25\2\u0159\u0158") - buf.write("\3\2\2\2\u0159\u015a\3\2\2\2\u015a\'\3\2\2\2\u015b\u015c") - buf.write("\7\b\2\2\u015c\u0161\5*\26\2\u015d\u015e\7\7\2\2\u015e") - buf.write("\u0160\5*\26\2\u015f\u015d\3\2\2\2\u0160\u0163\3\2\2\2") - buf.write("\u0161\u015f\3\2\2\2\u0161\u0162\3\2\2\2\u0162\u0164\3") - buf.write("\2\2\2\u0163\u0161\3\2\2\2\u0164\u0165\7\t\2\2\u0165)") - buf.write("\3\2\2\2\u0166\u0169\5\b\5\2\u0167\u0169\5\26\f\2\u0168") - buf.write("\u0166\3\2\2\2\u0168\u0167\3\2\2\2\u0169+\3\2\2\2\u016a") - buf.write("\u0175\5.\30\2\u016b\u016c\5\60\31\2\u016c\u016d\7\7\2") - buf.write("\2\u016d\u016f\3\2\2\2\u016e\u016b\3\2\2\2\u016f\u0172") - buf.write("\3\2\2\2\u0170\u016e\3\2\2\2\u0170\u0171\3\2\2\2\u0171") - buf.write("\u0173\3\2\2\2\u0172\u0170\3\2\2\2\u0173\u0175\5\62\32") - buf.write("\2\u0174\u016a\3\2\2\2\u0174\u0170\3\2\2\2\u0175-\3\2") - buf.write("\2\2\u0176\u017b\5\60\31\2\u0177\u0178\7\7\2\2\u0178\u017a") - buf.write("\5\60\31\2\u0179\u0177\3\2\2\2\u017a\u017d\3\2\2\2\u017b") - buf.write("\u0179\3\2\2\2\u017b\u017c\3\2\2\2\u017c\u017f\3\2\2\2") - buf.write("\u017d\u017b\3\2\2\2\u017e\u0176\3\2\2\2\u017e\u017f\3") - buf.write("\2\2\2\u017f/\3\2\2\2\u0180\u0183\5\b\5\2\u0181\u0182") - buf.write("\7\34\2\2\u0182\u0184\5\66\34\2\u0183\u0181\3\2\2\2\u0183") - buf.write("\u0184\3\2\2\2\u0184\61\3\2\2\2\u0185\u018a\5\64\33\2") - buf.write("\u0186\u0187\7\7\2\2\u0187\u0189\5\64\33\2\u0188\u0186") - buf.write("\3\2\2\2\u0189\u018c\3\2\2\2\u018a\u0188\3\2\2\2\u018a") - buf.write("\u018b\3\2\2\2\u018b\63\3\2\2\2\u018c\u018a\3\2\2\2\u018d") - buf.write("\u018e\7/\2\2\u018e\u018f\7\21\2\2\u018f\u0190\5\20\t") - buf.write("\2\u0190\65\3\2\2\2\u0191\u0192\7\b\2\2\u0192\u01c5\7") - buf.write("\t\2\2\u0193\u0194\7\b\2\2\u0194\u0195\5\66\34\2\u0195") - buf.write("\u0196\7\t\2\2\u0196\u01c5\3\2\2\2\u0197\u0198\7\b\2\2") - buf.write("\u0198\u0199\5\66\34\2\u0199\u019a\7\7\2\2\u019a\u019b") - buf.write("\7\t\2\2\u019b\u01c5\3\2\2\2\u019c\u019d\7\b\2\2\u019d") - buf.write("\u01a0\5\66\34\2\u019e\u019f\7\7\2\2\u019f\u01a1\5\66") - buf.write("\34\2\u01a0\u019e\3\2\2\2\u01a1\u01a2\3\2\2\2\u01a2\u01a0") - buf.write("\3\2\2\2\u01a2\u01a3\3\2\2\2\u01a3\u01a4\3\2\2\2\u01a4") - buf.write("\u01a5\7\t\2\2\u01a5\u01c5\3\2\2\2\u01a6\u01a7\5\4\3\2") - buf.write("\u01a7\u01a8\58\35\2\u01a8\u01c5\3\2\2\2\u01a9\u01c5\5") - buf.write("\4\3\2\u01aa\u01ab\7\35\2\2\u01ab\u01ac\7\n\2\2\u01ac") - buf.write("\u01ad\5:\36\2\u01ad\u01ae\7\7\2\2\u01ae\u01af\5\66\34") - buf.write("\2\u01af\u01b0\7\13\2\2\u01b0\u01c5\3\2\2\2\u01b1\u01b3") - buf.write("\7\24\2\2\u01b2\u01b4\58\35\2\u01b3\u01b2\3\2\2\2\u01b3") - buf.write("\u01b4\3\2\2\2\u01b4\u01b5\3\2\2\2\u01b5\u01be\7\b\2\2") - buf.write("\u01b6\u01bb\5\66\34\2\u01b7\u01b8\7\7\2\2\u01b8\u01ba") - buf.write("\5\66\34\2\u01b9\u01b7\3\2\2\2\u01ba\u01bd\3\2\2\2\u01bb") - buf.write("\u01b9\3\2\2\2\u01bb\u01bc\3\2\2\2\u01bc\u01bf\3\2\2\2") - buf.write("\u01bd\u01bb\3\2\2\2\u01be\u01b6\3\2\2\2\u01be\u01bf\3") - buf.write("\2\2\2\u01bf\u01c0\3\2\2\2\u01c0\u01c1\7\t\2\2\u01c1\u01c2") - buf.write("\7\25\2\2\u01c2\u01c5\5\66\34\2\u01c3\u01c5\7\6\2\2\u01c4") - buf.write("\u0191\3\2\2\2\u01c4\u0193\3\2\2\2\u01c4\u0197\3\2\2\2") - buf.write("\u01c4\u019c\3\2\2\2\u01c4\u01a6\3\2\2\2\u01c4\u01a9\3") - buf.write("\2\2\2\u01c4\u01aa\3\2\2\2\u01c4\u01b1\3\2\2\2\u01c4\u01c3") - buf.write("\3\2\2\2\u01c5\67\3\2\2\2\u01c6\u01c7\7\n\2\2\u01c7\u01cc") - buf.write("\5\66\34\2\u01c8\u01c9\7\7\2\2\u01c9\u01cb\5\66\34\2\u01ca") - buf.write("\u01c8\3\2\2\2\u01cb\u01ce\3\2\2\2\u01cc\u01ca\3\2\2\2") - buf.write("\u01cc\u01cd\3\2\2\2\u01cd\u01cf\3\2\2\2\u01ce\u01cc\3") - buf.write("\2\2\2\u01cf\u01d0\7\13\2\2\u01d09\3\2\2\2\u01d1\u01d2") - buf.write("\7\b\2\2\u01d2\u01df\7\t\2\2\u01d3\u01d4\7\b\2\2\u01d4") - buf.write("\u01d7\5> \2\u01d5\u01d6\7\7\2\2\u01d6\u01d8\5> \2\u01d7") - buf.write("\u01d5\3\2\2\2\u01d8\u01d9\3\2\2\2\u01d9\u01d7\3\2\2\2") - buf.write("\u01d9\u01da\3\2\2\2\u01da\u01db\3\2\2\2\u01db\u01dc\7") - buf.write("\t\2\2\u01dc\u01df\3\2\2\2\u01dd\u01df\5> \2\u01de\u01d1") - buf.write("\3\2\2\2\u01de\u01d3\3\2\2\2\u01de\u01dd\3\2\2\2\u01df") - buf.write(";\3\2\2\2\u01e0\u01e1\7\36\2\2\u01e1\u01e2\7\n\2\2\u01e2") - buf.write("\u01e3\7/\2\2\u01e3\u01e4\7\13\2\2\u01e4\u01e5\7\n\2\2") - buf.write("\u01e5\u01e6\7\61\2\2\u01e6\u01e7\7\13\2\2\u01e7=\3\2") - buf.write("\2\2\u01e8\u01ef\5<\37\2\u01e9\u01ea\7\b\2\2\u01ea\u01eb") - buf.write("\5> \2\u01eb\u01ec\7\t\2\2\u01ec\u01ef\3\2\2\2\u01ed\u01ef") - buf.write("\7\61\2\2\u01ee\u01e8\3\2\2\2\u01ee\u01e9\3\2\2\2\u01ee") - buf.write("\u01ed\3\2\2\2\u01ef?\3\2\2\2\u01f0\u01f1\7\16\2\2\u01f1") - buf.write("\u01f2\5\20\t\2\u01f2\u01f3\7\17\2\2\u01f3A\3\2\2\2\u01f4") - buf.write("\u01f8\7\60\2\2\u01f5\u01f8\7\61\2\2\u01f6\u01f8\7.\2") - buf.write("\2\u01f7\u01f4\3\2\2\2\u01f7\u01f5\3\2\2\2\u01f7\u01f6") - buf.write("\3\2\2\2\u01f8C\3\2\2\2\u01f9\u01fe\5\4\3\2\u01fa\u01fe") - buf.write("\5\6\4\2\u01fb\u01fe\5\b\5\2\u01fc\u01fe\5\n\6\2\u01fd") - buf.write("\u01f9\3\2\2\2\u01fd\u01fa\3\2\2\2\u01fd\u01fb\3\2\2\2") - buf.write("\u01fd\u01fc\3\2\2\2\u01feE\3\2\2\28JNQZknvz\u0091\u009b") - buf.write("\u009e\u00ad\u00c2\u00db\u00dd\u00e2\u00e9\u00f0\u00f7") - buf.write("\u00ff\u0104\u0108\u010c\u0115\u0119\u0122\u0127\u012e") - buf.write("\u0132\u013b\u0145\u014e\u0152\u0155\u0159\u0161\u0168") - buf.write("\u0170\u0174\u017b\u017e\u0183\u018a\u01a2\u01b3\u01bb") - buf.write("\u01be\u01c4\u01cc\u01d9\u01de\u01ee\u01f7\u01fd") - return buf.getvalue() - - -class RelayParser ( Parser ): - - grammarFileName = "Relay.g4" - - atn = ATNDeserializer().deserialize(serializedATN()) - - decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] - - sharedContextCache = PredictionContextCache() - - literalNames = [ "", "'.'", "'@'", "'%'", "'_'", "','", "'('", - "')'", "'['", "']'", "'if'", "'else'", "'{'", "'}'", - "'let'", "'='", "';'", "';;'", "'fn'", "'->'", "'def'", - "'extern'", "'type'", "'=>'", "'match'", "'match?'", - "':'", "'Tensor'", "'meta'", "'v0.0.4'", "", - "", "", "", "'*'", "'/'", - "'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", - "'!='" ] - - symbolicNames = [ "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "SEMVER", "COMMENT", "WS", "LINE_COMMENT", - "QUOTED_STRING", "MUL", "DIV", "ADD", "SUB", "LT", - "GT", "LE", "GE", "EQ", "NE", "BOOL_LIT", "CNAME", - "FLOAT", "NAT", "METADATA" ] - - RULE_prog = 0 - RULE_generalIdent = 1 - RULE_globalVar = 2 - RULE_localVar = 3 - RULE_graphVar = 4 - RULE_exprList = 5 - RULE_callList = 6 - RULE_expr = 7 - RULE_func = 8 - RULE_defn = 9 - RULE_constructorName = 10 - RULE_adtConsDefnList = 11 - RULE_adtConsDefn = 12 - RULE_matchClauseList = 13 - RULE_matchClause = 14 - RULE_matchType = 15 - RULE_patternList = 16 - RULE_pattern = 17 - RULE_adtCons = 18 - RULE_adtConsParamList = 19 - RULE_adtConsParam = 20 - RULE_argList = 21 - RULE_varList = 22 - RULE_var = 23 - RULE_attrSeq = 24 - RULE_attr = 25 - RULE_typeExpr = 26 - RULE_typeParamList = 27 - RULE_shapeList = 28 - RULE_meta = 29 - RULE_shape = 30 - RULE_body = 31 - RULE_scalar = 32 - RULE_ident = 33 - - ruleNames = [ "prog", "generalIdent", "globalVar", "localVar", "graphVar", - "exprList", "callList", "expr", "func", "defn", "constructorName", - "adtConsDefnList", "adtConsDefn", "matchClauseList", - "matchClause", "matchType", "patternList", "pattern", - "adtCons", "adtConsParamList", "adtConsParam", "argList", - "varList", "var", "attrSeq", "attr", "typeExpr", "typeParamList", - "shapeList", "meta", "shape", "body", "scalar", "ident" ] - - EOF = Token.EOF - T__0=1 - T__1=2 - T__2=3 - T__3=4 - T__4=5 - T__5=6 - T__6=7 - T__7=8 - T__8=9 - T__9=10 - T__10=11 - T__11=12 - T__12=13 - T__13=14 - T__14=15 - T__15=16 - T__16=17 - T__17=18 - T__18=19 - T__19=20 - T__20=21 - T__21=22 - T__22=23 - T__23=24 - T__24=25 - T__25=26 - T__26=27 - T__27=28 - SEMVER=29 - COMMENT=30 - WS=31 - LINE_COMMENT=32 - QUOTED_STRING=33 - MUL=34 - DIV=35 - ADD=36 - SUB=37 - LT=38 - GT=39 - LE=40 - GE=41 - EQ=42 - NE=43 - BOOL_LIT=44 - CNAME=45 - FLOAT=46 - NAT=47 - METADATA=48 - - def __init__(self, input:TokenStream, output:TextIO = sys.stdout): - super().__init__(input, output) - self.checkVersion("4.7.2") - self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) - self._predicates = None - - - - - class ProgContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def SEMVER(self): - return self.getToken(RelayParser.SEMVER, 0) - - def EOF(self): - return self.getToken(RelayParser.EOF, 0) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - - def METADATA(self): - return self.getToken(RelayParser.METADATA, 0) - - def defn(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.DefnContext) - else: - return self.getTypedRuleContext(RelayParser.DefnContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_prog - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitProg" ): - return visitor.visitProg(self) - else: - return visitor.visitChildren(self) - - - - - def prog(self): - - localctx = RelayParser.ProgContext(self, self._ctx, self.state) - self.enterRule(localctx, 0, self.RULE_prog) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 68 - self.match(RelayParser.SEMVER) - self.state = 76 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.EOF, RelayParser.T__19, RelayParser.T__20, RelayParser.T__21, RelayParser.METADATA]: - self.state = 72 - self._errHandler.sync(self) - _la = self._input.LA(1) - while (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__19) | (1 << RelayParser.T__20) | (1 << RelayParser.T__21))) != 0): - self.state = 69 - self.defn() - self.state = 74 - self._errHandler.sync(self) - _la = self._input.LA(1) - - pass - elif token in [RelayParser.T__1, RelayParser.T__2, RelayParser.T__5, RelayParser.T__7, RelayParser.T__9, RelayParser.T__13, RelayParser.T__17, RelayParser.T__23, RelayParser.T__24, RelayParser.T__27, RelayParser.QUOTED_STRING, RelayParser.SUB, RelayParser.BOOL_LIT, RelayParser.CNAME, RelayParser.FLOAT, RelayParser.NAT]: - self.state = 75 - self.expr(0) - pass - else: - raise NoViableAltException(self) - - self.state = 79 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.METADATA: - self.state = 78 - self.match(RelayParser.METADATA) - - - self.state = 81 - self.match(RelayParser.EOF) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class GeneralIdentContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def CNAME(self, i:int=None): - if i is None: - return self.getTokens(RelayParser.CNAME) - else: - return self.getToken(RelayParser.CNAME, i) - - def getRuleIndex(self): - return RelayParser.RULE_generalIdent - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitGeneralIdent" ): - return visitor.visitGeneralIdent(self) - else: - return visitor.visitChildren(self) - - - - - def generalIdent(self): - - localctx = RelayParser.GeneralIdentContext(self, self._ctx, self.state) - self.enterRule(localctx, 2, self.RULE_generalIdent) - try: - self.enterOuterAlt(localctx, 1) - self.state = 83 - self.match(RelayParser.CNAME) - self.state = 88 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,3,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: - self.state = 84 - self.match(RelayParser.T__0) - self.state = 85 - self.match(RelayParser.CNAME) - self.state = 90 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,3,self._ctx) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class GlobalVarContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def CNAME(self): - return self.getToken(RelayParser.CNAME, 0) - - def getRuleIndex(self): - return RelayParser.RULE_globalVar - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitGlobalVar" ): - return visitor.visitGlobalVar(self) - else: - return visitor.visitChildren(self) - - - - - def globalVar(self): - - localctx = RelayParser.GlobalVarContext(self, self._ctx, self.state) - self.enterRule(localctx, 4, self.RULE_globalVar) - try: - self.enterOuterAlt(localctx, 1) - self.state = 91 - self.match(RelayParser.T__1) - self.state = 92 - self.match(RelayParser.CNAME) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class LocalVarContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def CNAME(self): - return self.getToken(RelayParser.CNAME, 0) - - def getRuleIndex(self): - return RelayParser.RULE_localVar - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitLocalVar" ): - return visitor.visitLocalVar(self) - else: - return visitor.visitChildren(self) - - - - - def localVar(self): - - localctx = RelayParser.LocalVarContext(self, self._ctx, self.state) - self.enterRule(localctx, 6, self.RULE_localVar) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 94 - self.match(RelayParser.T__2) - self.state = 95 - _la = self._input.LA(1) - if not(_la==RelayParser.T__3 or _la==RelayParser.CNAME): - self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class GraphVarContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def NAT(self): - return self.getToken(RelayParser.NAT, 0) - - def getRuleIndex(self): - return RelayParser.RULE_graphVar - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitGraphVar" ): - return visitor.visitGraphVar(self) - else: - return visitor.visitChildren(self) - - - - - def graphVar(self): - - localctx = RelayParser.GraphVarContext(self, self._ctx, self.state) - self.enterRule(localctx, 8, self.RULE_graphVar) - try: - self.enterOuterAlt(localctx, 1) - self.state = 97 - self.match(RelayParser.T__2) - self.state = 98 - self.match(RelayParser.NAT) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ExprListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_exprList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitExprList" ): - return visitor.visitExprList(self) - else: - return visitor.visitChildren(self) - - - - - def exprList(self): - - localctx = RelayParser.ExprListContext(self, self._ctx, self.state) - self.enterRule(localctx, 10, self.RULE_exprList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 108 - self._errHandler.sync(self) - _la = self._input.LA(1) - if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__1) | (1 << RelayParser.T__2) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__9) | (1 << RelayParser.T__13) | (1 << RelayParser.T__17) | (1 << RelayParser.T__23) | (1 << RelayParser.T__24) | (1 << RelayParser.T__27) | (1 << RelayParser.QUOTED_STRING) | (1 << RelayParser.SUB) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.CNAME) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT))) != 0): - self.state = 100 - self.expr(0) - self.state = 105 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 101 - self.match(RelayParser.T__4) - self.state = 102 - self.expr(0) - self.state = 107 - self._errHandler.sync(self) - _la = self._input.LA(1) - - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class CallListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_callList - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class CallWithAttrContext(CallListContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.CallListContext - super().__init__(parser) - self.copyFrom(ctx) - - def attrSeq(self): - return self.getTypedRuleContext(RelayParser.AttrSeqContext,0) - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitCallWithAttr" ): - return visitor.visitCallWithAttr(self) - else: - return visitor.visitChildren(self) - - - class CallNoAttrContext(CallListContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.CallListContext - super().__init__(parser) - self.copyFrom(ctx) - - def exprList(self): - return self.getTypedRuleContext(RelayParser.ExprListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitCallNoAttr" ): - return visitor.visitCallNoAttr(self) - else: - return visitor.visitChildren(self) - - - - def callList(self): - - localctx = RelayParser.CallListContext(self, self._ctx, self.state) - self.enterRule(localctx, 12, self.RULE_callList) - try: - self.state = 120 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,7,self._ctx) - if la_ == 1: - localctx = RelayParser.CallNoAttrContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 110 - self.exprList() - pass - - elif la_ == 2: - localctx = RelayParser.CallWithAttrContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 116 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,6,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: - self.state = 111 - self.expr(0) - self.state = 112 - self.match(RelayParser.T__4) - self.state = 118 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,6,self._ctx) - - self.state = 119 - self.attrSeq() - pass - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ExprContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_expr - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - class FuncExprContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def func(self): - return self.getTypedRuleContext(RelayParser.FuncContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitFuncExpr" ): - return visitor.visitFuncExpr(self) - else: - return visitor.visitChildren(self) - - - class MetaExprContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def meta(self): - return self.getTypedRuleContext(RelayParser.MetaContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMetaExpr" ): - return visitor.visitMetaExpr(self) - else: - return visitor.visitChildren(self) - - - class MatchContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def matchType(self): - return self.getTypedRuleContext(RelayParser.MatchTypeContext,0) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - def matchClauseList(self): - return self.getTypedRuleContext(RelayParser.MatchClauseListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMatch" ): - return visitor.visitMatch(self) - else: - return visitor.visitChildren(self) - - - class TensorContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTensor" ): - return visitor.visitTensor(self) - else: - return visitor.visitChildren(self) - - - class GraphContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def graphVar(self): - return self.getTypedRuleContext(RelayParser.GraphVarContext,0) - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitGraph" ): - return visitor.visitGraph(self) - else: - return visitor.visitChildren(self) - - - class IdentExprContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def ident(self): - return self.getTypedRuleContext(RelayParser.IdentContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitIdentExpr" ): - return visitor.visitIdentExpr(self) - else: - return visitor.visitChildren(self) - - - class StringExprContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def QUOTED_STRING(self): - return self.getToken(RelayParser.QUOTED_STRING, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitStringExpr" ): - return visitor.visitStringExpr(self) - else: - return visitor.visitChildren(self) - - - class CallContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - def callList(self): - return self.getTypedRuleContext(RelayParser.CallListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitCall" ): - return visitor.visitCall(self) - else: - return visitor.visitChildren(self) - - - class NegContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def SUB(self): - return self.getToken(RelayParser.SUB, 0) - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitNeg" ): - return visitor.visitNeg(self) - else: - return visitor.visitChildren(self) - - - class TupleContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTuple" ): - return visitor.visitTuple(self) - else: - return visitor.visitChildren(self) - - - class ParenContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitParen" ): - return visitor.visitParen(self) - else: - return visitor.visitChildren(self) - - - class ScalarExprContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def scalar(self): - return self.getTypedRuleContext(RelayParser.ScalarContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitScalarExpr" ): - return visitor.visitScalarExpr(self) - else: - return visitor.visitChildren(self) - - - class LetContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def var(self): - return self.getTypedRuleContext(RelayParser.VarContext,0) - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitLet" ): - return visitor.visitLet(self) - else: - return visitor.visitChildren(self) - - - class ProjectionContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - def NAT(self): - return self.getToken(RelayParser.NAT, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitProjection" ): - return visitor.visitProjection(self) - else: - return visitor.visitChildren(self) - - - class IfElseContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - def body(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.BodyContext) - else: - return self.getTypedRuleContext(RelayParser.BodyContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitIfElse" ): - return visitor.visitIfElse(self) - else: - return visitor.visitChildren(self) - - - class BinOpContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.op = None # Token - self.copyFrom(ctx) - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - def MUL(self): - return self.getToken(RelayParser.MUL, 0) - def DIV(self): - return self.getToken(RelayParser.DIV, 0) - def ADD(self): - return self.getToken(RelayParser.ADD, 0) - def SUB(self): - return self.getToken(RelayParser.SUB, 0) - def LT(self): - return self.getToken(RelayParser.LT, 0) - def GT(self): - return self.getToken(RelayParser.GT, 0) - def LE(self): - return self.getToken(RelayParser.LE, 0) - def GE(self): - return self.getToken(RelayParser.GE, 0) - def EQ(self): - return self.getToken(RelayParser.EQ, 0) - def NE(self): - return self.getToken(RelayParser.NE, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitBinOp" ): - return visitor.visitBinOp(self) - else: - return visitor.visitChildren(self) - - - - def expr(self, _p:int=0): - _parentctx = self._ctx - _parentState = self.state - localctx = RelayParser.ExprContext(self, self._ctx, _parentState) - _prevctx = localctx - _startState = 14 - self.enterRecursionRule(localctx, 14, self.RULE_expr, _p) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 192 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,12,self._ctx) - if la_ == 1: - localctx = RelayParser.ParenContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - - self.state = 123 - self.match(RelayParser.T__5) - self.state = 124 - self.expr(0) - self.state = 125 - self.match(RelayParser.T__6) - pass - - elif la_ == 2: - localctx = RelayParser.NegContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 127 - self.match(RelayParser.SUB) - self.state = 128 - self.expr(20) - pass - - elif la_ == 3: - localctx = RelayParser.FuncExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 129 - self.func() - pass - - elif la_ == 4: - localctx = RelayParser.TupleContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 130 - self.match(RelayParser.T__5) - self.state = 131 - self.match(RelayParser.T__6) - pass - - elif la_ == 5: - localctx = RelayParser.TupleContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 132 - self.match(RelayParser.T__5) - self.state = 133 - self.expr(0) - self.state = 134 - self.match(RelayParser.T__4) - self.state = 135 - self.match(RelayParser.T__6) - pass - - elif la_ == 6: - localctx = RelayParser.TupleContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 137 - self.match(RelayParser.T__5) - self.state = 138 - self.expr(0) - self.state = 141 - self._errHandler.sync(self) - _la = self._input.LA(1) - while True: - self.state = 139 - self.match(RelayParser.T__4) - self.state = 140 - self.expr(0) - self.state = 143 - self._errHandler.sync(self) - _la = self._input.LA(1) - if not (_la==RelayParser.T__4): - break - - self.state = 145 - self.match(RelayParser.T__6) - pass - - elif la_ == 7: - localctx = RelayParser.TensorContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 147 - self.match(RelayParser.T__7) - self.state = 156 - self._errHandler.sync(self) - _la = self._input.LA(1) - if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__1) | (1 << RelayParser.T__2) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__9) | (1 << RelayParser.T__13) | (1 << RelayParser.T__17) | (1 << RelayParser.T__23) | (1 << RelayParser.T__24) | (1 << RelayParser.T__27) | (1 << RelayParser.QUOTED_STRING) | (1 << RelayParser.SUB) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.CNAME) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT))) != 0): - self.state = 148 - self.expr(0) - self.state = 153 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 149 - self.match(RelayParser.T__4) - self.state = 150 - self.expr(0) - self.state = 155 - self._errHandler.sync(self) - _la = self._input.LA(1) - - - - self.state = 158 - self.match(RelayParser.T__8) - pass - - elif la_ == 8: - localctx = RelayParser.IfElseContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 159 - self.match(RelayParser.T__9) - self.state = 160 - self.match(RelayParser.T__5) - self.state = 161 - self.expr(0) - self.state = 162 - self.match(RelayParser.T__6) - self.state = 163 - self.body() - self.state = 164 - self.match(RelayParser.T__10) - self.state = 165 - self.body() - pass - - elif la_ == 9: - localctx = RelayParser.MatchContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 167 - self.matchType() - self.state = 168 - self.expr(0) - self.state = 169 - self.match(RelayParser.T__11) - self.state = 171 - self._errHandler.sync(self) - _la = self._input.LA(1) - if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__2) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.CNAME))) != 0): - self.state = 170 - self.matchClauseList() - - - self.state = 173 - self.match(RelayParser.T__12) - pass - - elif la_ == 10: - localctx = RelayParser.LetContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 175 - self.match(RelayParser.T__13) - self.state = 176 - self.var() - self.state = 177 - self.match(RelayParser.T__14) - self.state = 178 - self.expr(0) - self.state = 179 - self.match(RelayParser.T__15) - self.state = 180 - self.expr(7) - pass - - elif la_ == 11: - localctx = RelayParser.GraphContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 182 - self.graphVar() - self.state = 183 - self.match(RelayParser.T__14) - self.state = 184 - self.expr(0) - self.state = 185 - self.match(RelayParser.T__15) - self.state = 186 - self.expr(5) - pass - - elif la_ == 12: - localctx = RelayParser.IdentExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 188 - self.ident() - pass - - elif la_ == 13: - localctx = RelayParser.ScalarExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 189 - self.scalar() - pass - - elif la_ == 14: - localctx = RelayParser.MetaExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 190 - self.meta() - pass - - elif la_ == 15: - localctx = RelayParser.StringExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 191 - self.match(RelayParser.QUOTED_STRING) - pass - - - self._ctx.stop = self._input.LT(-1) - self.state = 219 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,14,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: - if self._parseListeners is not None: - self.triggerExitRuleEvent() - _prevctx = localctx - self.state = 217 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,13,self._ctx) - if la_ == 1: - localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 194 - if not self.precpred(self._ctx, 19): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 19)") - self.state = 195 - localctx.op = self._input.LT(1) - _la = self._input.LA(1) - if not(_la==RelayParser.MUL or _la==RelayParser.DIV): - localctx.op = self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 196 - self.expr(20) - pass - - elif la_ == 2: - localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 197 - if not self.precpred(self._ctx, 18): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 18)") - self.state = 198 - localctx.op = self._input.LT(1) - _la = self._input.LA(1) - if not(_la==RelayParser.ADD or _la==RelayParser.SUB): - localctx.op = self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 199 - self.expr(19) - pass - - elif la_ == 3: - localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 200 - if not self.precpred(self._ctx, 17): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 17)") - self.state = 201 - localctx.op = self._input.LT(1) - _la = self._input.LA(1) - if not((((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.LT) | (1 << RelayParser.GT) | (1 << RelayParser.LE) | (1 << RelayParser.GE))) != 0)): - localctx.op = self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 202 - self.expr(18) - pass - - elif la_ == 4: - localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 203 - if not self.precpred(self._ctx, 16): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 16)") - self.state = 204 - localctx.op = self._input.LT(1) - _la = self._input.LA(1) - if not(_la==RelayParser.EQ or _la==RelayParser.NE): - localctx.op = self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 205 - self.expr(17) - pass - - elif la_ == 5: - localctx = RelayParser.LetContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 206 - if not self.precpred(self._ctx, 6): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 6)") - self.state = 207 - self.match(RelayParser.T__16) - self.state = 208 - self.expr(7) - pass - - elif la_ == 6: - localctx = RelayParser.CallContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 209 - if not self.precpred(self._ctx, 21): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 21)") - self.state = 210 - self.match(RelayParser.T__5) - self.state = 211 - self.callList() - self.state = 212 - self.match(RelayParser.T__6) - pass - - elif la_ == 7: - localctx = RelayParser.ProjectionContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 214 - if not self.precpred(self._ctx, 8): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 8)") - self.state = 215 - self.match(RelayParser.T__0) - self.state = 216 - self.match(RelayParser.NAT) - pass - - - self.state = 221 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,14,self._ctx) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.unrollRecursionContexts(_parentctx) - return localctx - - - class FuncContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def argList(self): - return self.getTypedRuleContext(RelayParser.ArgListContext,0) - - - def body(self): - return self.getTypedRuleContext(RelayParser.BodyContext,0) - - - def typeParamList(self): - return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) - - - def typeExpr(self): - return self.getTypedRuleContext(RelayParser.TypeExprContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_func - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitFunc" ): - return visitor.visitFunc(self) - else: - return visitor.visitChildren(self) - - - - - def func(self): - - localctx = RelayParser.FuncContext(self, self._ctx, self.state) - self.enterRule(localctx, 16, self.RULE_func) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 222 - self.match(RelayParser.T__17) - self.state = 224 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__7: - self.state = 223 - self.typeParamList() - - - self.state = 226 - self.match(RelayParser.T__5) - self.state = 227 - self.argList() - self.state = 228 - self.match(RelayParser.T__6) - self.state = 231 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__18: - self.state = 229 - self.match(RelayParser.T__18) - self.state = 230 - self.typeExpr() - - - self.state = 233 - self.body() - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class DefnContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_defn - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class ExternAdtDefnContext(DefnContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.DefnContext - super().__init__(parser) - self.copyFrom(ctx) - - def generalIdent(self): - return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0) - - def typeParamList(self): - return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitExternAdtDefn" ): - return visitor.visitExternAdtDefn(self) - else: - return visitor.visitChildren(self) - - - class FuncDefnContext(DefnContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.DefnContext - super().__init__(parser) - self.copyFrom(ctx) - - def globalVar(self): - return self.getTypedRuleContext(RelayParser.GlobalVarContext,0) - - def argList(self): - return self.getTypedRuleContext(RelayParser.ArgListContext,0) - - def body(self): - return self.getTypedRuleContext(RelayParser.BodyContext,0) - - def typeParamList(self): - return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) - - def typeExpr(self): - return self.getTypedRuleContext(RelayParser.TypeExprContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitFuncDefn" ): - return visitor.visitFuncDefn(self) - else: - return visitor.visitChildren(self) - - - class AdtDefnContext(DefnContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.DefnContext - super().__init__(parser) - self.copyFrom(ctx) - - def generalIdent(self): - return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0) - - def typeParamList(self): - return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) - - def adtConsDefnList(self): - return self.getTypedRuleContext(RelayParser.AdtConsDefnListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAdtDefn" ): - return visitor.visitAdtDefn(self) - else: - return visitor.visitChildren(self) - - - - def defn(self): - - localctx = RelayParser.DefnContext(self, self._ctx, self.state) - self.enterRule(localctx, 18, self.RULE_defn) - self._la = 0 # Token type - try: - self.state = 266 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.T__19]: - localctx = RelayParser.FuncDefnContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 235 - self.match(RelayParser.T__19) - self.state = 236 - self.globalVar() - self.state = 238 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__7: - self.state = 237 - self.typeParamList() - - - self.state = 240 - self.match(RelayParser.T__5) - self.state = 241 - self.argList() - self.state = 242 - self.match(RelayParser.T__6) - self.state = 245 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__18: - self.state = 243 - self.match(RelayParser.T__18) - self.state = 244 - self.typeExpr() - - - self.state = 247 - self.body() - pass - elif token in [RelayParser.T__20]: - localctx = RelayParser.ExternAdtDefnContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 249 - self.match(RelayParser.T__20) - self.state = 250 - self.match(RelayParser.T__21) - self.state = 251 - self.generalIdent() - self.state = 253 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__7: - self.state = 252 - self.typeParamList() - - - pass - elif token in [RelayParser.T__21]: - localctx = RelayParser.AdtDefnContext(self, localctx) - self.enterOuterAlt(localctx, 3) - self.state = 255 - self.match(RelayParser.T__21) - self.state = 256 - self.generalIdent() - self.state = 258 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__7: - self.state = 257 - self.typeParamList() - - - self.state = 260 - self.match(RelayParser.T__11) - self.state = 262 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.CNAME: - self.state = 261 - self.adtConsDefnList() - - - self.state = 264 - self.match(RelayParser.T__12) - pass - else: - raise NoViableAltException(self) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ConstructorNameContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def CNAME(self): - return self.getToken(RelayParser.CNAME, 0) - - def getRuleIndex(self): - return RelayParser.RULE_constructorName - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitConstructorName" ): - return visitor.visitConstructorName(self) - else: - return visitor.visitChildren(self) - - - - - def constructorName(self): - - localctx = RelayParser.ConstructorNameContext(self, self._ctx, self.state) - self.enterRule(localctx, 20, self.RULE_constructorName) - try: - self.enterOuterAlt(localctx, 1) - self.state = 268 - self.match(RelayParser.CNAME) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AdtConsDefnListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def adtConsDefn(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.AdtConsDefnContext) - else: - return self.getTypedRuleContext(RelayParser.AdtConsDefnContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_adtConsDefnList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAdtConsDefnList" ): - return visitor.visitAdtConsDefnList(self) - else: - return visitor.visitChildren(self) - - - - - def adtConsDefnList(self): - - localctx = RelayParser.AdtConsDefnListContext(self, self._ctx, self.state) - self.enterRule(localctx, 22, self.RULE_adtConsDefnList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 270 - self.adtConsDefn() - self.state = 275 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,23,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: - self.state = 271 - self.match(RelayParser.T__4) - self.state = 272 - self.adtConsDefn() - self.state = 277 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,23,self._ctx) - - self.state = 279 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__4: - self.state = 278 - self.match(RelayParser.T__4) - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AdtConsDefnContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def constructorName(self): - return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0) - - - def typeExpr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.TypeExprContext) - else: - return self.getTypedRuleContext(RelayParser.TypeExprContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_adtConsDefn - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAdtConsDefn" ): - return visitor.visitAdtConsDefn(self) - else: - return visitor.visitChildren(self) - - - - - def adtConsDefn(self): - - localctx = RelayParser.AdtConsDefnContext(self, self._ctx, self.state) - self.enterRule(localctx, 24, self.RULE_adtConsDefn) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 281 - self.constructorName() - self.state = 293 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__5: - self.state = 282 - self.match(RelayParser.T__5) - self.state = 283 - self.typeExpr() - self.state = 288 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 284 - self.match(RelayParser.T__4) - self.state = 285 - self.typeExpr() - self.state = 290 - self._errHandler.sync(self) - _la = self._input.LA(1) - - self.state = 291 - self.match(RelayParser.T__6) - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class MatchClauseListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def matchClause(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.MatchClauseContext) - else: - return self.getTypedRuleContext(RelayParser.MatchClauseContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_matchClauseList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMatchClauseList" ): - return visitor.visitMatchClauseList(self) - else: - return visitor.visitChildren(self) - - - - - def matchClauseList(self): - - localctx = RelayParser.MatchClauseListContext(self, self._ctx, self.state) - self.enterRule(localctx, 26, self.RULE_matchClauseList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 295 - self.matchClause() - self.state = 300 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,27,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: - self.state = 296 - self.match(RelayParser.T__4) - self.state = 297 - self.matchClause() - self.state = 302 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,27,self._ctx) - - self.state = 304 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__4: - self.state = 303 - self.match(RelayParser.T__4) - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class MatchClauseContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def pattern(self): - return self.getTypedRuleContext(RelayParser.PatternContext,0) - - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_matchClause - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMatchClause" ): - return visitor.visitMatchClause(self) - else: - return visitor.visitChildren(self) - - - - - def matchClause(self): - - localctx = RelayParser.MatchClauseContext(self, self._ctx, self.state) - self.enterRule(localctx, 28, self.RULE_matchClause) - try: - self.enterOuterAlt(localctx, 1) - self.state = 306 - self.pattern() - self.state = 307 - self.match(RelayParser.T__22) - self.state = 313 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.T__11]: - self.state = 308 - self.match(RelayParser.T__11) - self.state = 309 - self.expr(0) - self.state = 310 - self.match(RelayParser.T__12) - pass - elif token in [RelayParser.T__1, RelayParser.T__2, RelayParser.T__5, RelayParser.T__7, RelayParser.T__9, RelayParser.T__13, RelayParser.T__17, RelayParser.T__23, RelayParser.T__24, RelayParser.T__27, RelayParser.QUOTED_STRING, RelayParser.SUB, RelayParser.BOOL_LIT, RelayParser.CNAME, RelayParser.FLOAT, RelayParser.NAT]: - self.state = 312 - self.expr(0) - pass - else: - raise NoViableAltException(self) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class MatchTypeContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_matchType - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMatchType" ): - return visitor.visitMatchType(self) - else: - return visitor.visitChildren(self) - - - - - def matchType(self): - - localctx = RelayParser.MatchTypeContext(self, self._ctx, self.state) - self.enterRule(localctx, 30, self.RULE_matchType) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 315 - _la = self._input.LA(1) - if not(_la==RelayParser.T__23 or _la==RelayParser.T__24): - self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class PatternListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def pattern(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.PatternContext) - else: - return self.getTypedRuleContext(RelayParser.PatternContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_patternList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitPatternList" ): - return visitor.visitPatternList(self) - else: - return visitor.visitChildren(self) - - - - - def patternList(self): - - localctx = RelayParser.PatternListContext(self, self._ctx, self.state) - self.enterRule(localctx, 32, self.RULE_patternList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 317 - self.match(RelayParser.T__5) - self.state = 318 - self.pattern() - self.state = 323 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 319 - self.match(RelayParser.T__4) - self.state = 320 - self.pattern() - self.state = 325 - self._errHandler.sync(self) - _la = self._input.LA(1) - - self.state = 326 - self.match(RelayParser.T__6) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class PatternContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_pattern - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class WildcardPatternContext(PatternContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext - super().__init__(parser) - self.copyFrom(ctx) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitWildcardPattern" ): - return visitor.visitWildcardPattern(self) - else: - return visitor.visitChildren(self) - - - class ConstructorPatternContext(PatternContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext - super().__init__(parser) - self.copyFrom(ctx) - - def constructorName(self): - return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0) - - def patternList(self): - return self.getTypedRuleContext(RelayParser.PatternListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitConstructorPattern" ): - return visitor.visitConstructorPattern(self) - else: - return visitor.visitChildren(self) - - - class TuplePatternContext(PatternContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext - super().__init__(parser) - self.copyFrom(ctx) - - def patternList(self): - return self.getTypedRuleContext(RelayParser.PatternListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTuplePattern" ): - return visitor.visitTuplePattern(self) - else: - return visitor.visitChildren(self) - - - class VarPatternContext(PatternContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext - super().__init__(parser) - self.copyFrom(ctx) - - def localVar(self): - return self.getTypedRuleContext(RelayParser.LocalVarContext,0) - - def typeExpr(self): - return self.getTypedRuleContext(RelayParser.TypeExprContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitVarPattern" ): - return visitor.visitVarPattern(self) - else: - return visitor.visitChildren(self) - - - - def pattern(self): - - localctx = RelayParser.PatternContext(self, self._ctx, self.state) - self.enterRule(localctx, 34, self.RULE_pattern) - self._la = 0 # Token type - try: - self.state = 339 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.T__3]: - localctx = RelayParser.WildcardPatternContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 328 - self.match(RelayParser.T__3) - pass - elif token in [RelayParser.T__2]: - localctx = RelayParser.VarPatternContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 329 - self.localVar() - self.state = 332 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__25: - self.state = 330 - self.match(RelayParser.T__25) - self.state = 331 - self.typeExpr() - - - pass - elif token in [RelayParser.CNAME]: - localctx = RelayParser.ConstructorPatternContext(self, localctx) - self.enterOuterAlt(localctx, 3) - self.state = 334 - self.constructorName() - self.state = 336 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__5: - self.state = 335 - self.patternList() - - - pass - elif token in [RelayParser.T__5]: - localctx = RelayParser.TuplePatternContext(self, localctx) - self.enterOuterAlt(localctx, 4) - self.state = 338 - self.patternList() - pass - else: - raise NoViableAltException(self) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AdtConsContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def constructorName(self): - return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0) - - - def adtConsParamList(self): - return self.getTypedRuleContext(RelayParser.AdtConsParamListContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_adtCons - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAdtCons" ): - return visitor.visitAdtCons(self) - else: - return visitor.visitChildren(self) - - - - - def adtCons(self): - - localctx = RelayParser.AdtConsContext(self, self._ctx, self.state) - self.enterRule(localctx, 36, self.RULE_adtCons) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 341 - self.constructorName() - self.state = 343 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__5: - self.state = 342 - self.adtConsParamList() - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AdtConsParamListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def adtConsParam(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.AdtConsParamContext) - else: - return self.getTypedRuleContext(RelayParser.AdtConsParamContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_adtConsParamList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAdtConsParamList" ): - return visitor.visitAdtConsParamList(self) - else: - return visitor.visitChildren(self) - - - - - def adtConsParamList(self): - - localctx = RelayParser.AdtConsParamListContext(self, self._ctx, self.state) - self.enterRule(localctx, 38, self.RULE_adtConsParamList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 345 - self.match(RelayParser.T__5) - self.state = 346 - self.adtConsParam() - self.state = 351 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 347 - self.match(RelayParser.T__4) - self.state = 348 - self.adtConsParam() - self.state = 353 - self._errHandler.sync(self) - _la = self._input.LA(1) - - self.state = 354 - self.match(RelayParser.T__6) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AdtConsParamContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def localVar(self): - return self.getTypedRuleContext(RelayParser.LocalVarContext,0) - - - def constructorName(self): - return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_adtConsParam - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAdtConsParam" ): - return visitor.visitAdtConsParam(self) - else: - return visitor.visitChildren(self) - - - - - def adtConsParam(self): - - localctx = RelayParser.AdtConsParamContext(self, self._ctx, self.state) - self.enterRule(localctx, 40, self.RULE_adtConsParam) - try: - self.state = 358 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.T__2]: - self.enterOuterAlt(localctx, 1) - self.state = 356 - self.localVar() - pass - elif token in [RelayParser.CNAME]: - self.enterOuterAlt(localctx, 2) - self.state = 357 - self.constructorName() - pass - else: - raise NoViableAltException(self) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ArgListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_argList - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class ArgNoAttrContext(ArgListContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ArgListContext - super().__init__(parser) - self.copyFrom(ctx) - - def varList(self): - return self.getTypedRuleContext(RelayParser.VarListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitArgNoAttr" ): - return visitor.visitArgNoAttr(self) - else: - return visitor.visitChildren(self) - - - class ArgWithAttrContext(ArgListContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ArgListContext - super().__init__(parser) - self.copyFrom(ctx) - - def attrSeq(self): - return self.getTypedRuleContext(RelayParser.AttrSeqContext,0) - - def var(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.VarContext) - else: - return self.getTypedRuleContext(RelayParser.VarContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitArgWithAttr" ): - return visitor.visitArgWithAttr(self) - else: - return visitor.visitChildren(self) - - - - def argList(self): - - localctx = RelayParser.ArgListContext(self, self._ctx, self.state) - self.enterRule(localctx, 42, self.RULE_argList) - self._la = 0 # Token type - try: - self.state = 370 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,38,self._ctx) - if la_ == 1: - localctx = RelayParser.ArgNoAttrContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 360 - self.varList() - pass - - elif la_ == 2: - localctx = RelayParser.ArgWithAttrContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 366 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__2: - self.state = 361 - self.var() - self.state = 362 - self.match(RelayParser.T__4) - self.state = 368 - self._errHandler.sync(self) - _la = self._input.LA(1) - - self.state = 369 - self.attrSeq() - pass - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class VarListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def var(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.VarContext) - else: - return self.getTypedRuleContext(RelayParser.VarContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_varList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitVarList" ): - return visitor.visitVarList(self) - else: - return visitor.visitChildren(self) - - - - - def varList(self): - - localctx = RelayParser.VarListContext(self, self._ctx, self.state) - self.enterRule(localctx, 44, self.RULE_varList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 380 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__2: - self.state = 372 - self.var() - self.state = 377 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 373 - self.match(RelayParser.T__4) - self.state = 374 - self.var() - self.state = 379 - self._errHandler.sync(self) - _la = self._input.LA(1) - - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class VarContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def localVar(self): - return self.getTypedRuleContext(RelayParser.LocalVarContext,0) - - - def typeExpr(self): - return self.getTypedRuleContext(RelayParser.TypeExprContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_var - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitVar" ): - return visitor.visitVar(self) - else: - return visitor.visitChildren(self) - - - - - def var(self): - - localctx = RelayParser.VarContext(self, self._ctx, self.state) - self.enterRule(localctx, 46, self.RULE_var) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 382 - self.localVar() - self.state = 385 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__25: - self.state = 383 - self.match(RelayParser.T__25) - self.state = 384 - self.typeExpr() - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AttrSeqContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def attr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.AttrContext) - else: - return self.getTypedRuleContext(RelayParser.AttrContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_attrSeq - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAttrSeq" ): - return visitor.visitAttrSeq(self) - else: - return visitor.visitChildren(self) - - - - - def attrSeq(self): - - localctx = RelayParser.AttrSeqContext(self, self._ctx, self.state) - self.enterRule(localctx, 48, self.RULE_attrSeq) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 387 - self.attr() - self.state = 392 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 388 - self.match(RelayParser.T__4) - self.state = 389 - self.attr() - self.state = 394 - self._errHandler.sync(self) - _la = self._input.LA(1) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AttrContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def CNAME(self): - return self.getToken(RelayParser.CNAME, 0) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_attr - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAttr" ): - return visitor.visitAttr(self) - else: - return visitor.visitChildren(self) - - - - - def attr(self): - - localctx = RelayParser.AttrContext(self, self._ctx, self.state) - self.enterRule(localctx, 50, self.RULE_attr) - try: - self.enterOuterAlt(localctx, 1) - self.state = 395 - self.match(RelayParser.CNAME) - self.state = 396 - self.match(RelayParser.T__14) - self.state = 397 - self.expr(0) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class TypeExprContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_typeExpr - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class TypeParenContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def typeExpr(self): - return self.getTypedRuleContext(RelayParser.TypeExprContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTypeParen" ): - return visitor.visitTypeParen(self) - else: - return visitor.visitChildren(self) - - - class TupleTypeContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def typeExpr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.TypeExprContext) - else: - return self.getTypedRuleContext(RelayParser.TypeExprContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTupleType" ): - return visitor.visitTupleType(self) - else: - return visitor.visitChildren(self) - - - class TypeCallTypeContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def generalIdent(self): - return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0) - - def typeParamList(self): - return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTypeCallType" ): - return visitor.visitTypeCallType(self) - else: - return visitor.visitChildren(self) - - - class TypeIdentTypeContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def generalIdent(self): - return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTypeIdentType" ): - return visitor.visitTypeIdentType(self) - else: - return visitor.visitChildren(self) - - - class IncompleteTypeContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitIncompleteType" ): - return visitor.visitIncompleteType(self) - else: - return visitor.visitChildren(self) - - - class TensorTypeContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def shapeList(self): - return self.getTypedRuleContext(RelayParser.ShapeListContext,0) - - def typeExpr(self): - return self.getTypedRuleContext(RelayParser.TypeExprContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTensorType" ): - return visitor.visitTensorType(self) - else: - return visitor.visitChildren(self) - - - class FuncTypeContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def typeExpr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.TypeExprContext) - else: - return self.getTypedRuleContext(RelayParser.TypeExprContext,i) - - def typeParamList(self): - return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitFuncType" ): - return visitor.visitFuncType(self) - else: - return visitor.visitChildren(self) - - - - def typeExpr(self): - - localctx = RelayParser.TypeExprContext(self, self._ctx, self.state) - self.enterRule(localctx, 52, self.RULE_typeExpr) - self._la = 0 # Token type - try: - self.state = 450 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,47,self._ctx) - if la_ == 1: - localctx = RelayParser.TupleTypeContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 399 - self.match(RelayParser.T__5) - self.state = 400 - self.match(RelayParser.T__6) - pass - - elif la_ == 2: - localctx = RelayParser.TypeParenContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 401 - self.match(RelayParser.T__5) - self.state = 402 - self.typeExpr() - self.state = 403 - self.match(RelayParser.T__6) - pass - - elif la_ == 3: - localctx = RelayParser.TupleTypeContext(self, localctx) - self.enterOuterAlt(localctx, 3) - self.state = 405 - self.match(RelayParser.T__5) - self.state = 406 - self.typeExpr() - self.state = 407 - self.match(RelayParser.T__4) - self.state = 408 - self.match(RelayParser.T__6) - pass - - elif la_ == 4: - localctx = RelayParser.TupleTypeContext(self, localctx) - self.enterOuterAlt(localctx, 4) - self.state = 410 - self.match(RelayParser.T__5) - self.state = 411 - self.typeExpr() - self.state = 414 - self._errHandler.sync(self) - _la = self._input.LA(1) - while True: - self.state = 412 - self.match(RelayParser.T__4) - self.state = 413 - self.typeExpr() - self.state = 416 - self._errHandler.sync(self) - _la = self._input.LA(1) - if not (_la==RelayParser.T__4): - break - - self.state = 418 - self.match(RelayParser.T__6) - pass - - elif la_ == 5: - localctx = RelayParser.TypeCallTypeContext(self, localctx) - self.enterOuterAlt(localctx, 5) - self.state = 420 - self.generalIdent() - self.state = 421 - self.typeParamList() - pass - - elif la_ == 6: - localctx = RelayParser.TypeIdentTypeContext(self, localctx) - self.enterOuterAlt(localctx, 6) - self.state = 423 - self.generalIdent() - pass - - elif la_ == 7: - localctx = RelayParser.TensorTypeContext(self, localctx) - self.enterOuterAlt(localctx, 7) - self.state = 424 - self.match(RelayParser.T__26) - self.state = 425 - self.match(RelayParser.T__7) - self.state = 426 - self.shapeList() - self.state = 427 - self.match(RelayParser.T__4) - self.state = 428 - self.typeExpr() - self.state = 429 - self.match(RelayParser.T__8) - pass - - elif la_ == 8: - localctx = RelayParser.FuncTypeContext(self, localctx) - self.enterOuterAlt(localctx, 8) - self.state = 431 - self.match(RelayParser.T__17) - self.state = 433 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__7: - self.state = 432 - self.typeParamList() - - - self.state = 435 - self.match(RelayParser.T__5) - self.state = 444 - self._errHandler.sync(self) - _la = self._input.LA(1) - if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__17) | (1 << RelayParser.T__26) | (1 << RelayParser.CNAME))) != 0): - self.state = 436 - self.typeExpr() - self.state = 441 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 437 - self.match(RelayParser.T__4) - self.state = 438 - self.typeExpr() - self.state = 443 - self._errHandler.sync(self) - _la = self._input.LA(1) - - - - self.state = 446 - self.match(RelayParser.T__6) - self.state = 447 - self.match(RelayParser.T__18) - self.state = 448 - self.typeExpr() - pass - - elif la_ == 9: - localctx = RelayParser.IncompleteTypeContext(self, localctx) - self.enterOuterAlt(localctx, 9) - self.state = 449 - self.match(RelayParser.T__3) - pass - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class TypeParamListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def typeExpr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.TypeExprContext) - else: - return self.getTypedRuleContext(RelayParser.TypeExprContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_typeParamList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTypeParamList" ): - return visitor.visitTypeParamList(self) - else: - return visitor.visitChildren(self) - - - - - def typeParamList(self): - - localctx = RelayParser.TypeParamListContext(self, self._ctx, self.state) - self.enterRule(localctx, 54, self.RULE_typeParamList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 452 - self.match(RelayParser.T__7) - self.state = 453 - self.typeExpr() - self.state = 458 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 454 - self.match(RelayParser.T__4) - self.state = 455 - self.typeExpr() - self.state = 460 - self._errHandler.sync(self) - _la = self._input.LA(1) - - self.state = 461 - self.match(RelayParser.T__8) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ShapeListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def shape(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ShapeContext) - else: - return self.getTypedRuleContext(RelayParser.ShapeContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_shapeList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitShapeList" ): - return visitor.visitShapeList(self) - else: - return visitor.visitChildren(self) - - - - - def shapeList(self): - - localctx = RelayParser.ShapeListContext(self, self._ctx, self.state) - self.enterRule(localctx, 56, self.RULE_shapeList) - self._la = 0 # Token type - try: - self.state = 476 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,50,self._ctx) - if la_ == 1: - self.enterOuterAlt(localctx, 1) - self.state = 463 - self.match(RelayParser.T__5) - self.state = 464 - self.match(RelayParser.T__6) - pass - - elif la_ == 2: - self.enterOuterAlt(localctx, 2) - self.state = 465 - self.match(RelayParser.T__5) - self.state = 466 - self.shape() - self.state = 469 - self._errHandler.sync(self) - _la = self._input.LA(1) - while True: - self.state = 467 - self.match(RelayParser.T__4) - self.state = 468 - self.shape() - self.state = 471 - self._errHandler.sync(self) - _la = self._input.LA(1) - if not (_la==RelayParser.T__4): - break - - self.state = 473 - self.match(RelayParser.T__6) - pass - - elif la_ == 3: - self.enterOuterAlt(localctx, 3) - self.state = 475 - self.shape() - pass - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class MetaContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def CNAME(self): - return self.getToken(RelayParser.CNAME, 0) - - def NAT(self): - return self.getToken(RelayParser.NAT, 0) - - def getRuleIndex(self): - return RelayParser.RULE_meta - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMeta" ): - return visitor.visitMeta(self) - else: - return visitor.visitChildren(self) - - - - - def meta(self): - - localctx = RelayParser.MetaContext(self, self._ctx, self.state) - self.enterRule(localctx, 58, self.RULE_meta) - try: - self.enterOuterAlt(localctx, 1) - self.state = 478 - self.match(RelayParser.T__27) - self.state = 479 - self.match(RelayParser.T__7) - self.state = 480 - self.match(RelayParser.CNAME) - self.state = 481 - self.match(RelayParser.T__8) - self.state = 482 - self.match(RelayParser.T__7) - self.state = 483 - self.match(RelayParser.NAT) - self.state = 484 - self.match(RelayParser.T__8) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ShapeContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_shape - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class ParensShapeContext(ShapeContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext - super().__init__(parser) - self.copyFrom(ctx) - - def shape(self): - return self.getTypedRuleContext(RelayParser.ShapeContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitParensShape" ): - return visitor.visitParensShape(self) - else: - return visitor.visitChildren(self) - - - class MetaShapeContext(ShapeContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext - super().__init__(parser) - self.copyFrom(ctx) - - def meta(self): - return self.getTypedRuleContext(RelayParser.MetaContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMetaShape" ): - return visitor.visitMetaShape(self) - else: - return visitor.visitChildren(self) - - - class IntShapeContext(ShapeContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext - super().__init__(parser) - self.copyFrom(ctx) - - def NAT(self): - return self.getToken(RelayParser.NAT, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitIntShape" ): - return visitor.visitIntShape(self) - else: - return visitor.visitChildren(self) - - - - def shape(self): - - localctx = RelayParser.ShapeContext(self, self._ctx, self.state) - self.enterRule(localctx, 60, self.RULE_shape) - try: - self.state = 492 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.T__27]: - localctx = RelayParser.MetaShapeContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 486 - self.meta() - pass - elif token in [RelayParser.T__5]: - localctx = RelayParser.ParensShapeContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 487 - self.match(RelayParser.T__5) - self.state = 488 - self.shape() - self.state = 489 - self.match(RelayParser.T__6) - pass - elif token in [RelayParser.NAT]: - localctx = RelayParser.IntShapeContext(self, localctx) - self.enterOuterAlt(localctx, 3) - self.state = 491 - self.match(RelayParser.NAT) - pass - else: - raise NoViableAltException(self) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class BodyContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_body - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitBody" ): - return visitor.visitBody(self) - else: - return visitor.visitChildren(self) - - - - - def body(self): - - localctx = RelayParser.BodyContext(self, self._ctx, self.state) - self.enterRule(localctx, 62, self.RULE_body) - try: - self.enterOuterAlt(localctx, 1) - self.state = 494 - self.match(RelayParser.T__11) - self.state = 495 - self.expr(0) - self.state = 496 - self.match(RelayParser.T__12) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ScalarContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_scalar - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class ScalarFloatContext(ScalarContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext - super().__init__(parser) - self.copyFrom(ctx) - - def FLOAT(self): - return self.getToken(RelayParser.FLOAT, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitScalarFloat" ): - return visitor.visitScalarFloat(self) - else: - return visitor.visitChildren(self) - - - class ScalarBoolContext(ScalarContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext - super().__init__(parser) - self.copyFrom(ctx) - - def BOOL_LIT(self): - return self.getToken(RelayParser.BOOL_LIT, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitScalarBool" ): - return visitor.visitScalarBool(self) - else: - return visitor.visitChildren(self) - - - class ScalarIntContext(ScalarContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext - super().__init__(parser) - self.copyFrom(ctx) - - def NAT(self): - return self.getToken(RelayParser.NAT, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitScalarInt" ): - return visitor.visitScalarInt(self) - else: - return visitor.visitChildren(self) - - - - def scalar(self): - - localctx = RelayParser.ScalarContext(self, self._ctx, self.state) - self.enterRule(localctx, 64, self.RULE_scalar) - try: - self.state = 501 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.FLOAT]: - localctx = RelayParser.ScalarFloatContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 498 - self.match(RelayParser.FLOAT) - pass - elif token in [RelayParser.NAT]: - localctx = RelayParser.ScalarIntContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 499 - self.match(RelayParser.NAT) - pass - elif token in [RelayParser.BOOL_LIT]: - localctx = RelayParser.ScalarBoolContext(self, localctx) - self.enterOuterAlt(localctx, 3) - self.state = 500 - self.match(RelayParser.BOOL_LIT) - pass - else: - raise NoViableAltException(self) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class IdentContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def generalIdent(self): - return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0) - - - def globalVar(self): - return self.getTypedRuleContext(RelayParser.GlobalVarContext,0) - - - def localVar(self): - return self.getTypedRuleContext(RelayParser.LocalVarContext,0) - - - def graphVar(self): - return self.getTypedRuleContext(RelayParser.GraphVarContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_ident - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitIdent" ): - return visitor.visitIdent(self) - else: - return visitor.visitChildren(self) - - - - - def ident(self): - - localctx = RelayParser.IdentContext(self, self._ctx, self.state) - self.enterRule(localctx, 66, self.RULE_ident) - try: - self.state = 507 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,53,self._ctx) - if la_ == 1: - self.enterOuterAlt(localctx, 1) - self.state = 503 - self.generalIdent() - pass - - elif la_ == 2: - self.enterOuterAlt(localctx, 2) - self.state = 504 - self.globalVar() - pass - - elif la_ == 3: - self.enterOuterAlt(localctx, 3) - self.state = 505 - self.localVar() - pass - - elif la_ == 4: - self.enterOuterAlt(localctx, 4) - self.state = 506 - self.graphVar() - pass - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - - def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): - if self._predicates == None: - self._predicates = dict() - self._predicates[7] = self.expr_sempred - pred = self._predicates.get(ruleIndex, None) - if pred is None: - raise Exception("No predicate with index:" + str(ruleIndex)) - else: - return pred(localctx, predIndex) - - def expr_sempred(self, localctx:ExprContext, predIndex:int): - if predIndex == 0: - return self.precpred(self._ctx, 19) - - - if predIndex == 1: - return self.precpred(self._ctx, 18) - - - if predIndex == 2: - return self.precpred(self._ctx, 17) - - - if predIndex == 3: - return self.precpred(self._ctx, 16) - - - if predIndex == 4: - return self.precpred(self._ctx, 6) - - - if predIndex == 5: - return self.precpred(self._ctx, 21) - - - if predIndex == 6: - return self.precpred(self._ctx, 8) - - - - - diff --git a/python/tvm/relay/grammar/py3/RelayVisitor.py b/python/tvm/relay/grammar/py3/RelayVisitor.py deleted file mode 100644 index c6a7b7a0558c9..0000000000000 --- a/python/tvm/relay/grammar/py3/RelayVisitor.py +++ /dev/null @@ -1,343 +0,0 @@ -# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 -from antlr4 import * -if __name__ is not None and "." in __name__: - from .RelayParser import RelayParser -else: - from RelayParser import RelayParser - -# This class defines a complete generic visitor for a parse tree produced by RelayParser. - -class RelayVisitor(ParseTreeVisitor): - - # Visit a parse tree produced by RelayParser#prog. - def visitProg(self, ctx:RelayParser.ProgContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#generalIdent. - def visitGeneralIdent(self, ctx:RelayParser.GeneralIdentContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#globalVar. - def visitGlobalVar(self, ctx:RelayParser.GlobalVarContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#localVar. - def visitLocalVar(self, ctx:RelayParser.LocalVarContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#graphVar. - def visitGraphVar(self, ctx:RelayParser.GraphVarContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#exprList. - def visitExprList(self, ctx:RelayParser.ExprListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#callNoAttr. - def visitCallNoAttr(self, ctx:RelayParser.CallNoAttrContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#callWithAttr. - def visitCallWithAttr(self, ctx:RelayParser.CallWithAttrContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#funcExpr. - def visitFuncExpr(self, ctx:RelayParser.FuncExprContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#metaExpr. - def visitMetaExpr(self, ctx:RelayParser.MetaExprContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#match. - def visitMatch(self, ctx:RelayParser.MatchContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#tensor. - def visitTensor(self, ctx:RelayParser.TensorContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#graph. - def visitGraph(self, ctx:RelayParser.GraphContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#identExpr. - def visitIdentExpr(self, ctx:RelayParser.IdentExprContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#stringExpr. - def visitStringExpr(self, ctx:RelayParser.StringExprContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#call. - def visitCall(self, ctx:RelayParser.CallContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#neg. - def visitNeg(self, ctx:RelayParser.NegContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#tuple. - def visitTuple(self, ctx:RelayParser.TupleContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#paren. - def visitParen(self, ctx:RelayParser.ParenContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#scalarExpr. - def visitScalarExpr(self, ctx:RelayParser.ScalarExprContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#let. - def visitLet(self, ctx:RelayParser.LetContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#projection. - def visitProjection(self, ctx:RelayParser.ProjectionContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#ifElse. - def visitIfElse(self, ctx:RelayParser.IfElseContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#binOp. - def visitBinOp(self, ctx:RelayParser.BinOpContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#func. - def visitFunc(self, ctx:RelayParser.FuncContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#funcDefn. - def visitFuncDefn(self, ctx:RelayParser.FuncDefnContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#externAdtDefn. - def visitExternAdtDefn(self, ctx:RelayParser.ExternAdtDefnContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#adtDefn. - def visitAdtDefn(self, ctx:RelayParser.AdtDefnContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#constructorName. - def visitConstructorName(self, ctx:RelayParser.ConstructorNameContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#adtConsDefnList. - def visitAdtConsDefnList(self, ctx:RelayParser.AdtConsDefnListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#adtConsDefn. - def visitAdtConsDefn(self, ctx:RelayParser.AdtConsDefnContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#matchClauseList. - def visitMatchClauseList(self, ctx:RelayParser.MatchClauseListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#matchClause. - def visitMatchClause(self, ctx:RelayParser.MatchClauseContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#matchType. - def visitMatchType(self, ctx:RelayParser.MatchTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#patternList. - def visitPatternList(self, ctx:RelayParser.PatternListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#wildcardPattern. - def visitWildcardPattern(self, ctx:RelayParser.WildcardPatternContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#varPattern. - def visitVarPattern(self, ctx:RelayParser.VarPatternContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#constructorPattern. - def visitConstructorPattern(self, ctx:RelayParser.ConstructorPatternContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#tuplePattern. - def visitTuplePattern(self, ctx:RelayParser.TuplePatternContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#adtCons. - def visitAdtCons(self, ctx:RelayParser.AdtConsContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#adtConsParamList. - def visitAdtConsParamList(self, ctx:RelayParser.AdtConsParamListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#adtConsParam. - def visitAdtConsParam(self, ctx:RelayParser.AdtConsParamContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#argNoAttr. - def visitArgNoAttr(self, ctx:RelayParser.ArgNoAttrContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#argWithAttr. - def visitArgWithAttr(self, ctx:RelayParser.ArgWithAttrContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#varList. - def visitVarList(self, ctx:RelayParser.VarListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#var. - def visitVar(self, ctx:RelayParser.VarContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#attrSeq. - def visitAttrSeq(self, ctx:RelayParser.AttrSeqContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#attr. - def visitAttr(self, ctx:RelayParser.AttrContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#tupleType. - def visitTupleType(self, ctx:RelayParser.TupleTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#typeParen. - def visitTypeParen(self, ctx:RelayParser.TypeParenContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#typeCallType. - def visitTypeCallType(self, ctx:RelayParser.TypeCallTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#typeIdentType. - def visitTypeIdentType(self, ctx:RelayParser.TypeIdentTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#tensorType. - def visitTensorType(self, ctx:RelayParser.TensorTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#funcType. - def visitFuncType(self, ctx:RelayParser.FuncTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#incompleteType. - def visitIncompleteType(self, ctx:RelayParser.IncompleteTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#typeParamList. - def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#shapeList. - def visitShapeList(self, ctx:RelayParser.ShapeListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#meta. - def visitMeta(self, ctx:RelayParser.MetaContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#metaShape. - def visitMetaShape(self, ctx:RelayParser.MetaShapeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#parensShape. - def visitParensShape(self, ctx:RelayParser.ParensShapeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#intShape. - def visitIntShape(self, ctx:RelayParser.IntShapeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#body. - def visitBody(self, ctx:RelayParser.BodyContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#scalarFloat. - def visitScalarFloat(self, ctx:RelayParser.ScalarFloatContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#scalarInt. - def visitScalarInt(self, ctx:RelayParser.ScalarIntContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#scalarBool. - def visitScalarBool(self, ctx:RelayParser.ScalarBoolContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#ident. - def visitIdent(self, ctx:RelayParser.IdentContext): - return self.visitChildren(ctx) - - - -del RelayParser \ No newline at end of file diff --git a/python/tvm/relay/grammar/py3/__init__.py b/python/tvm/relay/grammar/py3/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/python/tvm/relay/grammar/py3/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py deleted file mode 100644 index 6c4e3131e3c26..0000000000000 --- a/python/tvm/relay/parser.py +++ /dev/null @@ -1,30 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""A parser for Relay's text format.""" -from __future__ import absolute_import -from .. import register_func - - -@register_func("relay.fromtext") -def fromtext(data, source_name=None): - """Parse a Relay program.""" - # pylint: disable=import-outside-toplevel - from tvm.relay import _parser - x = _parser.fromtext(data + "\n", source_name) - if x is None: - raise Exception("cannot parse: ", data) - return x diff --git a/python/tvm/relay/std/core.rly b/python/tvm/relay/std/core.rly index 6a3facc3424c8..f469491a56f1b 100644 --- a/python/tvm/relay/std/core.rly +++ b/python/tvm/relay/std/core.rly @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -v0.0.4 + +#[version = "0.0.5"] extern type Storage diff --git a/python/tvm/relay/std/gradient.rly b/python/tvm/relay/std/gradient.rly index ed81e4b2d4546..7594f4ebc5f48 100644 --- a/python/tvm/relay/std/gradient.rly +++ b/python/tvm/relay/std/gradient.rly @@ -16,7 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -v0.0.4 + +#[version = "0.0.5"] /* * Store the Gradient Value of a Tensor of type T. diff --git a/python/tvm/relay/std/prelude.rly b/python/tvm/relay/std/prelude.rly index fa05d1a7bd989..17c91283f4d2b 100644 --- a/python/tvm/relay/std/prelude.rly +++ b/python/tvm/relay/std/prelude.rly @@ -16,9 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -v0.0.4 - -// TODO(weberlo): should we add sugar for scalar types (e.g., `int32` => `Tensor[(), int32]`)? +#[version = "0.0.5"] def @id[A](%x: A) -> A { %x @@ -298,7 +296,7 @@ def @size[A](%t: Tree[A]) -> Tensor[(), int32] { * Takes a number n and a function f; returns a closure that takes an argument * and applies f n times to its argument. */ -def @iterate[A](%f: fn(A) -> A, %n: Tensor[(), int32]) -> (fn(A) -> A) { +def @iterate[A](%f: fn(A) -> A, %n: Tensor[(), int32]) -> fn(A) -> A { if (%n == 0) { @id } else { diff --git a/src/ir/module.cc b/src/ir/module.cc index 25ecab2455cbf..b34740865fc60 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -29,6 +29,7 @@ // and are only used in minimum cases where they are clearly marked. // // Rationale: We calls into relay's analysis module to verify correctness. +#include #include #include @@ -371,10 +372,7 @@ void IRModuleNode::ImportFromStd(const String& path) { std::unordered_set IRModuleNode::Imports() const { return this->import_set_; } IRModule IRModule::FromText(const String& text, const String& source_path) { - auto* f = tvm::runtime::Registry::Get("relay.fromtext"); - CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; - IRModule mod = (*f)(text, source_path); - return mod; + return tvm::parser::ParseModule(source_path, text); } TVM_REGISTER_NODE_TYPE(IRModuleNode); diff --git a/src/ir/span.cc b/src/ir/span.cc index 64b42ab4dc144..d9c9bbc47c341 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -61,23 +61,35 @@ TVM_REGISTER_NODE_TYPE(SourceNameNode) return static_cast(n)->name; }); -Span::Span(SourceName source, int lineno, int col_offset) { +Span::Span(SourceName source_name, int line, int end_line, int column, int end_column) { auto n = make_object(); - n->source = std::move(source); - n->line = lineno; - n->column = col_offset; + n->source_name = std::move(source_name); + n->line = line; + n->end_line = end_line; + n->column = column; + n->end_column = end_column; data_ = std::move(n); } +Span Span::Merge(const Span& other) { + CHECK((*this)->source_name == other->source_name); + return Span((*this)->source_name, std::min((*this)->line, other->line), + std::max((*this)->end_line, other->end_line), + std::min((*this)->column, other->column), + std::max((*this)->end_column, other->end_column)); +} + TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int lineno, int col_offset) { - return Span(source, lineno, col_offset); +TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source_name, int line, int end_line, + int column, int end_column) { + return Span(source_name, line, end_line, column, end_column); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "Span(" << node->source << ", " << node->line << ", " << node->column << ")"; + p->stream << "Span(" << node->source_name << ", " << node->line << ", " << node->end_line + << ", " << node->column << ", " << node->end_column << ")"; }); } // namespace tvm diff --git a/src/parser/diagnostic.h b/src/parser/diagnostic.h index 19f5d205126a9..2eb38b312242e 100644 --- a/src/parser/diagnostic.h +++ b/src/parser/diagnostic.h @@ -31,6 +31,7 @@ #define TVM_PARSER_DIAGNOSTIC_H_ #include +#include #include #include @@ -42,93 +43,17 @@ namespace tvm { namespace parser { -/*! \brief A program source in any language. - * - * Could represent the source from an ML framework or the internal - * source of a TVM program. - */ -struct Source { - /*! \brief The raw source. */ - std::string source; - /*! \brief A mapping of line breaks into the raw source. */ - std::vector> line_map; - - /*! \brief An empty source. */ - Source() : source(), line_map() {} - - /*! \brief Construct a source from a string. */ - explicit Source(const std::string& source) : source(source) { - int index = 0; - int length = 0; - line_map.push_back({index, length}); - for (auto c : source) { - if (c == '\n') { - // Record the length of the line. - line_map.back().second = length; - // Bump past the newline. - index += 1; - // Record the start of the next line, and put placeholder for length. - line_map.push_back({index, 0}); - // Reset length to zero. - length = 0; - } else { - length += 1; - index += 1; - } - } - line_map.back().second = length; - } - - Source(const Source& source) : source(source.source), line_map(source.line_map) {} - - /*! \brief Generate an error message at a specific line and column with the - * annotated message. - * - * The error is written directly to the `out` std::ostream. - * - * \param out The output ostream. - * \param line The line at which to report a diagnostic. - * \param line The column at which to report a diagnostic. - * \param msg The message to attach. - */ - void ReportAt(std::ostream& out, int line, int column, const std::string& msg) const { - CHECK(line - 1 <= static_cast(line_map.size())) - << "requested line: " << (line - 1) << "line_map size: " << line_map.size() - << "source: " << source; - - // Adjust for zero indexing, now have (line_start, line_length); - auto range = line_map.at(line - 1); - int line_start = range.first; - int line_length = range.second; - out << "file:" << line << ":" << column << ": parse error: " << msg << std::endl; - out << " " << source.substr(line_start, line_length) << std::endl; - out << " "; - std::stringstream marker; - for (int i = 1; i <= line_length; i++) { - if (i == column) { - marker << "^"; - } else if ((column - i) < 3) { - marker << "~"; - } else if ((i - column) < 3) { - marker << "~"; - } else { - marker << " "; - } - } - out << marker.str(); - out << std::endl; - } -}; - /*! \brief The diagnostic level, controls the printing of the message. */ -enum DiagnosticLevel { - Bug, - Error, - Warning, - Note, - Help, +enum class DiagnosticLevel { + kBug, + kError, + kWarning, + kNote, + kHelp, }; +struct DiagnosticBuilder; + /*! \brief A diagnostic message. */ struct Diagnostic { /*! \brief The level. */ @@ -138,10 +63,81 @@ struct Diagnostic { /*! \brief The diagnostic message. */ std::string message; - Diagnostic(int line, int column, const std::string& message) - : level(DiagnosticLevel::Error), span(SourceName(), line, column), message(message) {} + Diagnostic(DiagnosticLevel level, Span span, const std::string& message) + : level(level), span(span), message(message) {} + + static DiagnosticBuilder Bug(Span span); + static DiagnosticBuilder Error(Span span); + static DiagnosticBuilder Warning(Span span); + static DiagnosticBuilder Note(Span span); + static DiagnosticBuilder Help(Span span); }; +/*! + * \brief A wrapper around std::stringstream to build a diagnostic. + * + * \code + * + * void ReportError(const Error& err); + * + * void Test(int number) { + * // Use error reporter to construct an error. + * ReportError(ErrorBuilder() << "This is an error number=" << number); + * } + * + * \endcode + */ +struct DiagnosticBuilder { + public: + /*! \brief The level. */ + DiagnosticLevel level; + + /*! \brief The source name. */ + SourceName source_name; + + /*! \brief The span of the diagnostic. */ + Span span; + + template + DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*) + stream_ << val; + return *this; + } + + DiagnosticBuilder() : level(DiagnosticLevel::kError), source_name(), span(Span()) {} + + DiagnosticBuilder(const DiagnosticBuilder& builder) + : level(builder.level), source_name(builder.source_name), span(builder.span) {} + + DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {} + + operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); } + + private: + std::stringstream stream_; + friend struct Diagnostic; +}; + +DiagnosticBuilder Diagnostic::Bug(Span span) { + return DiagnosticBuilder(DiagnosticLevel::kBug, span); +} + +DiagnosticBuilder Diagnostic::Error(Span span) { + return DiagnosticBuilder(DiagnosticLevel::kError, span); +} + +DiagnosticBuilder Diagnostic::Warning(Span span) { + return DiagnosticBuilder(DiagnosticLevel::kWarning, span); +} + +DiagnosticBuilder Diagnostic::Note(Span span) { + return DiagnosticBuilder(DiagnosticLevel::kNote, span); +} + +DiagnosticBuilder Diagnostic::Help(Span span) { + return DiagnosticBuilder(DiagnosticLevel::kHelp, span); +} + /*! \brief A diagnostic context for recording errors against a source file. * TODO(@jroesch): convert source map and improve in follow up PR, the parser * assumes a single global file for now. @@ -158,15 +154,22 @@ struct DiagnosticContext { /*! \brief Emit a diagnostic. */ void Emit(const Diagnostic& diagnostic) { diagnostics.push_back(diagnostic); } + /*! \brief Emit a diagnostic. */ + void EmitFatal(const Diagnostic& diagnostic) { + diagnostics.push_back(diagnostic); + Render(std::cout); + } + // TODO(@jroesch): eventually modularize the rendering interface to provide control of how to // format errors. void Render(std::ostream& ostream) { for (auto diagnostic : diagnostics) { - source.ReportAt(ostream, diagnostic.span->line, diagnostic.span->column, diagnostic.message); + source.ReportAt(ostream, diagnostic.span, diagnostic.message); } if (diagnostics.size()) { - LOG(FATAL) << "parse error occured"; + LOG(FATAL) << "DiagnosticError: one or more error diagnostics were " + << "emitted, please check diagnostic render for output."; } } }; diff --git a/src/parser/meta_ref.cc b/src/parser/meta_ref.cc new file mode 100644 index 0000000000000..d23892753c5fd --- /dev/null +++ b/src/parser/meta_ref.cc @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/parser/meta_ref.cc + * \brief An operator which allows forward referencing a yet-to-be parsed meta table reference. + */ + +#include "./meta_ref.h" + +#include +#include +#include +#include + +namespace tvm { +namespace parser { + +using tvm::relay::transform::CreateFunctionPass; +using tvm::transform::PassContext; + +/* Set to arbitrary high number, since we should never schedule in normal pass manager flow. */ +static int kMetaExpandOptLevel = 1337; + +TVM_REGISTER_NODE_TYPE(MetaRefAttrs); + +bool MetaRefRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + LOG(FATAL) << "need to expand before type checking"; + return true; +} + +RELAY_REGISTER_OP("parser.MetaRef") + .describe(R"code(A reference into the meta table.)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(0) + .set_support_level(10) + .add_type_rel("MetaRef", MetaRefRel) + .set_attr("TOpIsStateful", false) + .set_attr("TNonComputational", true); + +Expr MetaRef(std::string type_key, uint64_t node_index) { + static const Op& op = Op::Get("parser.MetaRef"); + auto attrs = make_object(); + attrs->node_type_key = tvm::String(type_key); + attrs->node_index = node_index; + return Call(op, {}, Attrs(attrs), {}); +} + +struct MetaRefExpander : public ExprMutator { + MetaTable table; + + explicit MetaRefExpander(const MetaTable& table) : table(table) {} + + Expr VisitExpr_(const CallNode* call) final { + if (auto op_node = call->op.as()) { + if (op_node->name == "parser.MetaRef") { + auto meta_attrs = call->attrs.as(); + CHECK(meta_attrs) << "an internal error has occurred"; + auto nodes = table.at(meta_attrs->node_type_key); + CHECK_LT(meta_attrs->node_index, nodes.size()); + return Downcast(nodes[meta_attrs->node_index]); + } + } + + return ExprMutator::VisitExpr_(call); + } +}; + +Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func) { + MetaRefExpander expander(meta_table); + return Downcast(expander.VisitExpr(func)); +} + +IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod) { + auto pass = CreateFunctionPass([&](Function func, IRModule module, + PassContext ctx) { return ExpandMetaRefs(meta_table, func); }, + kMetaExpandOptLevel, "ExpandMetaRefs", {}); + + return pass(mod, PassContext::Create()); +} + +} // namespace parser +} // namespace tvm diff --git a/src/parser/meta_ref.h b/src/parser/meta_ref.h new file mode 100644 index 0000000000000..481f334cb0fe0 --- /dev/null +++ b/src/parser/meta_ref.h @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file meta_ref.h + * \brief A reference into the metadata section of the Relay text format. + */ + +#ifndef TVM_PARSER_META_REF_H_ +#define TVM_PARSER_META_REF_H_ + +#include +#include +#include + +#include + +namespace tvm { +namespace parser { + +using namespace relay; + +using MetaTable = Map>; + +/*! + * \brief Options for allocating storage. + */ +struct MetaRefAttrs : public tvm::AttrsNode { + tvm::String node_type_key; + uint64_t node_index; + + TVM_DECLARE_ATTRS(MetaRefAttrs, "relay.attrs.MetaRefAttrs") { + TVM_ATTR_FIELD(node_type_key) + .describe("The type_key representing the type of the node referenced."); + TVM_ATTR_FIELD(node_index).describe("The index into the type specific node array."); + } +}; + +/*! \brief A reference to a "meta-expression". + * + * In the text format we allow referencing metadata which + * uses a compact serialization that proceeds the main + * program body. + * + * We can reference this table using an expression of + * the form `meta[Type][index]`. + * + * We must later resolve these references to actual in-memory + * AST nodes but this requires first parsing the full program + * then expanding these temporary AST nodes into their corresponding + * nodes. + * + * For example the nth large constant will be pretty-printed as meta[relay.Constant][n] + * with its compact binary serialization residing in the metadata section at the end + * of the program. + * + * \param type_key The type key of the object in the meta section. + * \param node_index The index into that subfield. + * \returns The meta table reference. + */ +Expr MetaRef(std::string type_key, uint64_t node_index); + +relay::Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func); +IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod); + +} // namespace parser +} // namespace tvm + +#endif // TVM_PARSER_META_REF_H_ diff --git a/src/parser/op_table.h b/src/parser/op_table.h index 5af10a0590b80..050904f23280e 100644 --- a/src/parser/op_table.h +++ b/src/parser/op_table.h @@ -80,16 +80,16 @@ struct OperatorTable { OperatorTable DefaultOpTable() { return OperatorTable( - {Rule({TokenType::Star}, Op::Get("multiply"), 12, 2, true), - Rule({TokenType::Division}, Op::Get("divide"), 12, 2, true), - Rule({TokenType::Plus}, Op::Get("add"), 10, 2, true), - Rule({TokenType::Minus}, Op::Get("subtract"), 10, 2, true), - Rule({TokenType::LAngle}, Op::Get("less"), 8, 2, true), - Rule({TokenType::LAngle, TokenType::Equal}, Op::Get("less_equal"), 8, 2, true), - Rule({TokenType::RAngle}, Op::Get("greater"), 8, 2, true), - Rule({TokenType::RAngle, TokenType::Equal}, Op::Get("greater_equal"), 8, 2, true), - Rule({TokenType::Equal, TokenType::Equal}, Op::Get("equal"), 7, 2, true), - Rule({TokenType::Bang, TokenType::Equal}, Op::Get("not_equal"), 7, 2, true)}); + {Rule({TokenType::kStar}, Op::Get("multiply"), 12, 2, true), + Rule({TokenType::kDivision}, Op::Get("divide"), 12, 2, true), + Rule({TokenType::kPlus}, Op::Get("add"), 10, 2, true), + Rule({TokenType::kMinus}, Op::Get("subtract"), 10, 2, true), + Rule({TokenType::kLAngle}, Op::Get("less"), 8, 2, true), + Rule({TokenType::kLAngle, TokenType::kEqual}, Op::Get("less_equal"), 8, 2, true), + Rule({TokenType::kRAngle}, Op::Get("greater"), 8, 2, true), + Rule({TokenType::kRAngle, TokenType::kEqual}, Op::Get("greater_equal"), 8, 2, true), + Rule({TokenType::kEqual, TokenType::kEqual}, Op::Get("equal"), 7, 2, true), + Rule({TokenType::kBang, TokenType::kEqual}, Op::Get("not_equal"), 7, 2, true)}); } } // namespace parser diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 0aaa698be45e8..7877245725ea2 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -32,6 +32,7 @@ #include #include "./diagnostic.h" +#include "./meta_ref.h" #include "./op_table.h" #include "./tokenizer.h" @@ -41,6 +42,9 @@ namespace parser { using namespace relay; using Expr = relay::Expr; +/*! \brief The meta table maps from type key to a sequence of objects. */ +using MetaTable = Map>; + /*! \brief A wrapper structure for capturing the result of parsing * a global definition *before* we add it to the IRModule. * @@ -87,60 +91,6 @@ class SemVer { patch_version(other.patch_version) {} }; -/*! \brief A reference to a "meta-expression". - * - * In the text format we allow referencing metadata which - * uses a compact serialization that proceeds the main - * program body. - * - * We can reference this table using an expression of - * the form `meta[Type][index]`. - * - * We must later resolve these references to actual in-memory - * AST nodes but this requires first parsing the full program - * then expanding these temporary AST nodes into their corresponding - * nodes. - * - * For example the nth large constant will be pretty-printed as meta[relay.Constant][n] - * with its compact binary serialization residing in the metadata section at the end - * of the program. - */ -class MetaRefExprNode : public TempExprNode { - public: - /*! \brief The type key of the meta expression. */ - std::string type_key; - /*! \brief The index into the type key's table. */ - uint64_t node_index; - - void VisitAttrs(tvm::AttrVisitor* v) {} - - // TODO(@jroesch): we probably will need to manually - // expand these with a pass. - Expr Realize() const final { return Expr(); } - - static constexpr const char* _type_key = "relay.MetaRefExpr"; - TVM_DECLARE_FINAL_OBJECT_INFO(MetaRefExprNode, TempExprNode); -}; - -class MetaRefExpr : public TempExpr { - public: - /*! - * \brief The constructor for MetaRefExpr - * \param type_key The type key of the object in the meta section. - * \param kind The index into that subfield. - */ - TVM_DLL MetaRefExpr(std::string type_key, uint64_t node_index); - - TVM_DEFINE_OBJECT_REF_METHODS(MetaRefExpr, TempExpr, MetaRefExprNode); -}; - -MetaRefExpr::MetaRefExpr(std::string type_key, uint64_t node_index) { - auto rnode = make_object(); - rnode->type_key = type_key; - rnode->node_index = node_index; - data_ = std::move(rnode); -} - /*! \brief A simple wrapper around a mapping from raw string names * to a TVM variable, type variable or other binder type. */ @@ -164,6 +114,7 @@ template class ScopeStack { private: std::vector> scope_stack; + std::unordered_map free_vars; public: /*! \brief Adds a variable binding to the current scope. */ @@ -174,6 +125,8 @@ class ScopeStack { this->scope_stack.back().name_map.insert({name, value}); } + void AddFreeVar(const std::string& name, const T& value) { free_vars.insert({name, value}); } + /*! \brief Looks up a variable name in the scope stack returning the matching variable * in most recent scope. */ T Lookup(const std::string& name) { @@ -183,6 +136,13 @@ class ScopeStack { return it->second; } } + + // Check if we bound a free variable declaration. + auto it = free_vars.find(name); + if (it != free_vars.end()) { + return it->second; + } + return T(); } @@ -193,17 +153,22 @@ class ScopeStack { void PopStack() { this->scope_stack.pop_back(); } }; +struct DuplicateKeyError : public dmlc::Error { + explicit DuplicateKeyError(const std::string& msg) : dmlc::Error(msg) {} +}; + /*! \brief A table of interning strings as global function and type names. */ template struct InternTable { /*! \brief The internal table mapping strings to a unique allocation. */ std::unordered_map table; + DiagnosticContext* ctx; /*! \brief Add the unique allocation. */ void Add(const std::string& name, const T& t) { auto it = table.find(name); if (it != table.end()) { - LOG(FATAL) << "duplicate name"; + throw DuplicateKeyError("duplicate key name in intern table"); } else { table.insert({name, t}); } @@ -264,7 +229,9 @@ class Parser { SemVer version; /*! \brief The diagnostic context used for error reporting. */ - DiagnosticContext diag_ctx; + DiagnosticContext* diag_ctx; + + const SourceName& source_name; /*! \brief The current position in the token stream. */ int pos; @@ -296,8 +263,18 @@ class Parser { /*! \brief The set of expression scopes used for lexical scope. */ ScopeStack expr_scopes; - Parser(std::vector tokens, OperatorTable op_table, Source source) - : diag_ctx(source), pos(0), tokens(tokens), op_table(op_table), ignore_whitespace(true) {} + /*! \brief The metadata section. */ + MetaTable meta_table; + + Parser(DiagnosticContext* ctx, const SourceName& source_name, std::vector tokens, + OperatorTable op_table, Source source, MetaTable table) + : diag_ctx(ctx), + source_name(source_name), + pos(0), + tokens(tokens), + op_table(op_table), + ignore_whitespace(true), + meta_table(table) {} /*! \brief Examine the next token in the stream, the current parser is configured to be * whitespace insensitive so we will skip all whitespace or comment tokens. */ @@ -305,10 +282,10 @@ class Parser { // For now we ignore all whitespace tokens and comments. // We can tweak this behavior later to enable white space sensitivity in the parser. while (pos < static_cast(tokens.size()) && ignore_whitespace && - (tokens.at(pos)->token_type == TokenType::Whitespace || - tokens.at(pos)->token_type == TokenType::Newline || - tokens.at(pos)->token_type == TokenType::LineComment || - tokens.at(pos)->token_type == TokenType::Comment)) { + (tokens.at(pos)->token_type == TokenType::kWhitespace || + tokens.at(pos)->token_type == TokenType::kNewline || + tokens.at(pos)->token_type == TokenType::kLineComment || + tokens.at(pos)->token_type == TokenType::kComment)) { pos++; } @@ -345,10 +322,9 @@ class Parser { */ void Consume(const TokenType& token_type) { if (tokens[pos]->token_type != token_type) { - std::string message = - "expected a " + Pretty(token_type) + " found " + Pretty(Peek()->token_type); - this->diag_ctx.Emit({tokens[pos]->line, tokens[pos]->column, message}); - this->diag_ctx.Render(std::cout); + this->diag_ctx->EmitFatal(Diagnostic::Error(tokens[pos]->span) + << "expected a " << Pretty(token_type) << " found " + << Pretty(Peek()->token_type)); } pos++; } @@ -409,6 +385,17 @@ class Parser { return var; } + /*! \brief Bind a local variable in the expression scope. + * + * "x" -> Var("x"), these are needed to map from the raw string names + * to unique variable nodes. + */ + Var BindFreeVar(const std::string& name, const relay::Type& type_annotation) { + auto var = Var(name, type_annotation); + this->expr_scopes.AddFreeVar(name, var); + return var; + } + /*! \brief Bind a type variable in the type scope. * * "A" -> TypeVar("A", ...), these are needed to map from raw string names @@ -427,8 +414,8 @@ class Parser { Var LookupLocal(const Token& local) { auto var = this->expr_scopes.Lookup(local.ToString()); if (!var.defined()) { - diag_ctx.Emit( - {local->line, local->column, "this local variable has not been previously declared"}); + diag_ctx->Emit(Diagnostic::Error(local->span) + << "this local variable has not been previously declared"); } return var; } @@ -440,9 +427,9 @@ class Parser { TypeVar LookupTypeVar(const Token& ident) { auto var = this->type_scopes.Lookup(ident.ToString()); if (!var.defined()) { - diag_ctx.Emit( - {ident->line, ident->column, - "this type variable has not been previously declared anywhere, perhaps a typo?"}); + diag_ctx->Emit( + Diagnostic::Error(ident->span) + << "this type variable has not been previously declared anywhere, perhaps a typo?"); } return var; } @@ -469,7 +456,7 @@ class Parser { /*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */ NDArray NumberToNDArray(const Token& token) { - if (token->token_type == TokenType::Integer) { + if (token->token_type == TokenType::kInteger) { DLContext ctx = {DLDeviceType::kDLCPU, 0}; auto dtype = String2DLDataType("int32"); auto data = NDArray::Empty({}, dtype, ctx); @@ -478,7 +465,7 @@ class Parser { int64_t value = Downcast(token->data); array[0] = (int32_t)value; return data; - } else if (token->token_type == TokenType::Float) { + } else if (token->token_type == TokenType::kFloat) { DLContext ctx = {DLDeviceType::kDLCPU, 0}; auto dtype = String2DLDataType("float32"); auto data = NDArray::Empty({}, dtype, ctx); @@ -520,15 +507,38 @@ class Parser { /*! \brief Parse `(` parser() `)`. */ template R Parens(std::function parser) { - return Bracket(TokenType::OpenParen, TokenType::CloseParen, parser); + return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, parser); } /*! \brief Parse `{` parser() `}`. */ template R Block(std::function parser) { - return Bracket(TokenType::LCurly, TokenType::RCurly, parser); + return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser); } + ObjectRef ParseMetaRef() { + auto meta_ref = Match(TokenType::kMetaReference); + Call ref = Downcast(meta_ref->data); + auto attrs = ref->attrs.as(); + auto type_key = attrs->node_type_key; + auto index = attrs->node_index; + auto it = this->meta_table.find(type_key); + if (it != this->meta_table.end()) { + auto nodes = (*it).second; + if (index < nodes.size()) { + return nodes[index]; + } else { + this->diag_ctx->Emit(Diagnostic::Error(meta_ref->span) + << "the node index `" << index << "` is out of bounds for `" + << type_key << "`"); + return ObjectRef(); + } + } else { + this->diag_ctx->Emit(Diagnostic::Error(meta_ref->span) + << "no entry in the meta table for `" << type_key << "`"); + return ObjectRef(); + } + } /*! \brief Parses a sequence beginning with a start token, seperated by a seperator token, and * ending with a stop token. * @@ -542,31 +552,44 @@ class Parser { */ template Array ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function parse, - std::function before_stop = nullptr) { + std::function before_stop = nullptr) { + DLOG(INFO) << "Parser::ParseSequence: start=" << ToString(start) << "sep=" << ToString(sep) + << "stop=" << ToString(stop); Match(start); + + // This is for the empty arguments list case, if we have token stream + // we must parse leftovers, then match a stop token. + if (before_stop) { + auto did_parse = before_stop(); + if (did_parse) { + Match(stop); + return {}; + } + } + + // This is the case in which we find an empty arguments lists and no leftovers. if (WhenMatch(stop)) { return Array(); } else { auto data = parse(); Array elements = {data}; - // parse '(' expr ')' - // if we are at the end invoke leftover parser - if (Peek()->token_type == stop && before_stop) { - before_stop(); - } if (WhenMatch(stop)) { return elements; // parse '( expr ',' * ')' } else if (WhenMatch(sep)) { - // if we are at the end invoke leftover parser - if (Peek()->token_type == stop && before_stop) { - before_stop(); - } while (true) { if (WhenMatch(stop)) { break; } else { + // If before stop is + if (before_stop) { + auto did_parse = before_stop(); + if (did_parse) { + Match(stop); + return elements; + } + } auto data = parse(); WhenMatch(sep); elements.push_back(data); @@ -574,7 +597,10 @@ class Parser { } return elements; } else { - LOG(FATAL) << "issue"; + auto next = Peek(); + this->diag_ctx->EmitFatal(Diagnostic::Error(next->span) + << "expected a " << Pretty(stop) << " found " + << Pretty(next->token_type)); return Array(nullptr); } } @@ -588,7 +614,8 @@ class Parser { auto defs = ParseDefinitions(); // Parse the metadata section at the end. auto metadata = ParseMetadata(); - Match(TokenType::EndOfFile); + + Match(TokenType::kEndOfFile); Map funcs; Map types; @@ -606,24 +633,21 @@ class Parser { } /*! \brief Parse the semantic versioning header. */ - SemVer ParseSemVer() { - // TODO(@jroesch): convert semver to module level attribute. - auto id = Peek(); - if (id->token_type == TokenType::Identifier && id.ToString() == "v0") { - auto id = Match(TokenType::Identifier); - Consume(TokenType::Period); - Consume(TokenType::Float); + SemVer ParseSemVer(bool required = true) { + if (Peek()->token_type == TokenType::kVersion) { + auto version = Match(TokenType::kVersion); + // TODO(@jroesch): we currently only support 0.0.5. + if (version.ToString() != "\"0.0.5\"") { + this->diag_ctx->Emit(Diagnostic::Error(version->span) + << "invalid semantic version `" << version.ToString() << "`"); + } + } else if (required) { + this->diag_ctx->Emit(Diagnostic::Error(Peek()->span) + << "expected text format semantic version, found a " + << PrettyPrint(Peek()) + << "you can annotate it as #[version = \"0.0.5\"]"); } - // TODO(@jroesch): the current lexing makes it hard to parse this - // in a way that doesnt feel like a hack. - // - // We should move to module level attributes instead - // so we can tag modules with top-level data. - // - // #[text_version = "0.0.4"] - // - // For now we only support current version. - return SemVer(0, 0, 4); + return SemVer(0, 0, 5); } /*! \brief Parse zero or more Relay definitions. */ @@ -633,25 +657,32 @@ class Parser { while (true) { auto next = Peek(); switch (next->token_type) { - case TokenType::Defn: { - Consume(TokenType::Defn); - auto global_name = Match(TokenType::Global).ToString(); + case TokenType::kDefn: { + Consume(TokenType::kDefn); + auto global_tok = Match(TokenType::kGlobal); + auto global_name = global_tok.ToString(); auto global = GlobalVar(global_name); - global_names.Add(global_name, global); + try { + global_names.Add(global_name, global); + } catch (DuplicateKeyError e) { + this->diag_ctx->Emit(Diagnostic::Error(global_tok->span) << "a function with the name " + << "`@" << global_name << "` " + << "was previously defined"); + } auto func = ParseFunctionDef(); defs.funcs.push_back(GlobalFunc(global, func)); continue; } - case TokenType::TypeDef: { + case TokenType::kTypeDef: { defs.types.push_back(ParseTypeDef()); continue; } - case TokenType::Extern: { - Consume(TokenType::Extern); + case TokenType::kExtern: { + Consume(TokenType::kExtern); auto type_def = ParseTypeDef(); if (type_def->constructors.size()) { - diag_ctx.Emit( - {next->line, next->column, "an external type may not have any constructors"}); + diag_ctx->Emit(Diagnostic::Error(next->span) + << "an external type may not have any constructors"); } defs.types.push_back(type_def); } @@ -664,48 +695,64 @@ class Parser { /*! \brief Parse zero or more Relay type definitions. */ TypeData ParseTypeDef() { // Match the `type` keyword. - Match(TokenType::TypeDef); + Match(TokenType::kTypeDef); // Parse the type's identifier. - auto type_id = Match(TokenType::Identifier).ToString(); + auto type_tok = Match(TokenType::kIdentifier); + auto type_id = type_tok.ToString(); auto type_global = tvm::GlobalTypeVar(type_id, TypeKind::kAdtHandle); - type_names.Add(type_id, type_global); + + try { + type_names.Add(type_id, type_global); + } catch (DuplicateKeyError e) { + this->diag_ctx->Emit(Diagnostic::Error(type_tok->span) << "a type definition with the name " + << "`" << type_id << "` " + << "was previously defined"); + } Array generics; bool should_pop = false; - if (Peek()->token_type == TokenType::LSquare) { + if (Peek()->token_type == TokenType::kLSquare) { // If we have generics we need to add a type scope. PushTypeScope(); should_pop = true; - generics = - ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, [&]() { - auto type_var_name = Match(TokenType::Identifier).ToString(); + generics = ParseSequence( + TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { + auto type_var_name = Match(TokenType::kIdentifier).ToString(); return BindTypeVar(type_var_name, TypeKind::kType); }); } Array ctors; - if (Peek()->token_type == TokenType::LCurly) { + if (Peek()->token_type == TokenType::kLCurly) { // Parse the list of constructors. ctors = ParseSequence( - TokenType::LCurly, TokenType::Comma, TokenType::RCurly, [&]() { + TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&]() { // First match the name of the constructor. - auto ctor_name = Match(TokenType::Identifier).ToString(); + auto ctor_tok = Match(TokenType::kIdentifier); + auto ctor_name = ctor_tok.ToString(); Constructor ctor; // Match the optional field list. - if (Peek()->token_type != TokenType::OpenParen) { + if (Peek()->token_type != TokenType::kOpenParen) { ctor = tvm::Constructor(ctor_name, {}, type_global); } else { auto arg_types = - ParseSequence(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, - [&]() { return ParseType(); }); + ParseSequence(TokenType::kOpenParen, TokenType::kComma, + TokenType::kCloseParen, [&]() { return ParseType(); }); ctor = tvm::Constructor(ctor_name, arg_types, type_global); } CHECK(ctor.defined()); - this->ctors.Add(ctor_name, ctor); + try { + this->ctors.Add(ctor_name, ctor); + } catch (DuplicateKeyError e) { + this->diag_ctx->EmitFatal(Diagnostic::Error(ctor_tok->span) + << "a constructor with the name " + << "`" << ctor_name << "` " + << "was previously defined"); + } return ctor; }); @@ -745,41 +792,62 @@ class Parser { /*! \brief Parse a single Relay expression. */ Expr ParseExpr() { + DLOG(INFO) << "Parser::ParseExpr"; return ConsumeWhitespace([this] { std::vector exprs; while (true) { + DLOG(INFO) << "Parser::ParseExpr: parsing a single expression"; auto next = Peek(); switch (next->token_type) { // For graph or let, match first rhs, then invoke ParseBindingExpr // ParseBindingExpression then parse_lhs() parse_rhs() ';' continue - case TokenType::LCurly: { + case TokenType::kLCurly: { // NB: Might need to optimize to remove deep recursion. // Stack should only grow proportionally to the number of // nested scopes. - return Bracket(TokenType::LCurly, TokenType::RCurly, [&]() { + // Parses `{` expression `}`. + auto block = Bracket(TokenType::kLCurly, TokenType::kRCurly, [&]() { PushScope(); auto expr = ParseExpr(); PopScopes(1); return expr; }); + exprs.push_back(block); + break; } - case TokenType::Let: + case TokenType::kFreeVar: { + Consume(TokenType::kFreeVar); + auto var_token = Match(TokenType::kLocal); + + Type type; + if (WhenMatch(TokenType::kColon)) { + type = ParseType(); + } else { + type = IncompleteType(); + } + + BindFreeVar(var_token.ToString(), type); + break; + } + // Parses `let ...`; + case TokenType::kLet: exprs.push_back(ParseBindingExpr()); break; - case TokenType::Match: - case TokenType::PartialMatch: { - bool is_total = next->token_type == TokenType::Match; + case TokenType::kMatch: + case TokenType::kPartialMatch: { + bool is_total = next->token_type == TokenType::kMatch; Consume(next->token_type); exprs.push_back(ParseMatch(is_total)); break; } - case TokenType::If: { + case TokenType::kIf: { exprs.push_back(ParseIf()); break; } - case TokenType::Graph: - if (Lookahead(2)->token_type == TokenType::Equal) { + // %x ... + case TokenType::kGraph: + if (Lookahead(2)->token_type == TokenType::kEqual) { exprs.push_back(ParseBindingExpr()); break; } @@ -790,7 +858,7 @@ class Parser { } } - if (!WhenMatch(TokenType::Semicolon)) { + if (!WhenMatch(TokenType::kSemicolon)) { break; } } @@ -838,39 +906,40 @@ class Parser { // This ensures for n sequential bindings // the call depth will be the same before // and after parsing the n bindings. + DLOG(INFO) << "Parser::ParseBindingExpr"; std::vector> bindings; int scopes = 0; while (true) { auto next = Peek(); - if (next->token_type == TokenType::Graph && Lookahead(2)->token_type == TokenType::Equal) { - Match(TokenType::Graph); - Match(TokenType::Equal); + if (next->token_type == TokenType::kGraph && Lookahead(2)->token_type == TokenType::kEqual) { + Match(TokenType::kGraph); + Match(TokenType::kEqual); auto val = this->ParseExprBinOp(); - Match(TokenType::Semicolon); + Match(TokenType::kSemicolon); AddGraphBinding(next, val); - } else if (next->token_type == TokenType::Let) { + } else if (next->token_type == TokenType::kLet) { // Parse the 'let'. - Consume(TokenType::Let); + Consume(TokenType::kLet); // Parse the local '%'. - auto local_tok = Match(TokenType::Local); + auto local_tok = Match(TokenType::kLocal); auto string = local_tok.ToString(); // Parse the optional type annotation (':' ). Type type; - if (WhenMatch(TokenType::Colon)) { + if (WhenMatch(TokenType::kColon)) { type = ParseType(); } auto var = BindVar(string, type); // Parse the '='; - Match(TokenType::Equal); + Match(TokenType::kEqual); // Parse the body, and the ';'. auto val = this->ParseExprBinOp(); - Consume(TokenType::Semicolon); + Consume(TokenType::kSemicolon); // Add the bindings to the local data structure. bindings.push_back({var, val}); @@ -905,37 +974,52 @@ class Parser { /*! Parse a function definition without a leading keyword or identifier. * - * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN, UN) -> Ret { body }. + * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }. */ Function ParseFunctionDef() { + DLOG(INFO) << "Parser::ParseFunctionDef"; PushScope(); PushTypeScope(); Array generics; - if (Peek()->token_type == TokenType::LSquare) { + if (Peek()->token_type == TokenType::kLSquare) { // If we have generics we need to add a type scope. PushTypeScope(); - generics = - ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, [&]() { - auto type_var_name = Match(TokenType::Identifier).ToString(); + generics = ParseSequence( + TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { + auto type_var_name = Match(TokenType::kIdentifier).ToString(); return BindTypeVar(type_var_name, TypeKind::kType); }); } - auto params = - ParseSequence(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, [&]() { - auto token = Match(TokenType::Local); + Map raw_attrs; + + auto params = ParseSequence( + TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, + [&]() { + auto token = Match(TokenType::kLocal); auto string = token.ToString(); Type type; - if (WhenMatch(TokenType::Colon)) { + if (WhenMatch(TokenType::kColon)) { type = ParseType(); } return BindVar(string, type); + }, + [&] { + auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; + auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; + + if (is_ident && next_is_equal) { + raw_attrs = ParseAttrs(); + return true; + } + + return false; }); Type ret_type; - if (WhenMatch(TokenType::Minus)) { - Match(TokenType::RAngle); + if (WhenMatch(TokenType::kMinus)) { + Match(TokenType::kRAngle); ret_type = ParseType(); } @@ -944,26 +1028,42 @@ class Parser { PopTypeScopes(1); PopScopes(1); - return relay::Function(params, body, ret_type, generics); + // TODO(@jroesch): attributes should never be null, they should always be empty. + if (raw_attrs.size()) { + return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); + } else { + return relay::Function(params, body, ret_type, generics); + } } /*! \brief Parse an if-expression. */ Expr ParseIf() { - Consume(TokenType::If); + DLOG(INFO) << "Parser::ParseIf"; + Consume(TokenType::kIf); auto guard = Parens([&] { return ParseExpr(); }); - auto true_branch = Block([&] { return ParseExpr(); }); + auto true_branch = Block([&] { + this->PushScope(); + auto expr = ParseExpr(); + this->PopScopes(1); + return expr; + }); - Match(TokenType::Else); + Match(TokenType::kElse); - auto false_branch = Block([&] { return ParseExpr(); }); + auto false_branch = Block([&] { + this->PushScope(); + auto expr = ParseExpr(); + this->PopScopes(1); + return expr; + }); return relay::If(guard, true_branch, false_branch); } /* This factors parsing a list of patterns for both tuples, and constructors. */ Array ParsePatternList() { - return ParseSequence(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, + return ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&] { return ParsePattern(); }); } @@ -975,26 +1075,27 @@ class Parser { * This function recursively parses a pattern. */ Pattern ParsePattern() { + DLOG(INFO) << "Parser::ParsePattern"; auto next = Peek(); switch (next->token_type) { - case TokenType::Underscore: { - Match(TokenType::Underscore); + case TokenType::kUnderscore: { + Match(TokenType::kUnderscore); return PatternWildcard(); } - case TokenType::Local: { - auto id = Match(TokenType::Local); + case TokenType::kLocal: { + auto id = Match(TokenType::kLocal); Type type_annotation; - if (WhenMatch(TokenType::Colon)) { + if (WhenMatch(TokenType::kColon)) { type_annotation = ParseType(); } auto var = BindVar(id.ToString(), type_annotation); return PatternVar(var); } - case TokenType::Identifier: { - auto id = Match(TokenType::Identifier); + case TokenType::kIdentifier: { + auto id = Match(TokenType::kIdentifier); auto ctor = ctors.Get(id.ToString()); CHECK(ctor) << "undefined identifier"; - if (Peek()->token_type == TokenType::OpenParen) { + if (Peek()->token_type == TokenType::kOpenParen) { auto fields = ParsePatternList(); return PatternConstructor(ctor.value(), fields); } else { @@ -1009,8 +1110,8 @@ class Parser { Clause ParseMatchArm() { PushScope(); auto pattern = ParsePattern(); - Match(TokenType::Equal); - Consume(TokenType::RAngle); + Match(TokenType::kEqual); + Consume(TokenType::kRAngle); auto expr = ParseExpr(); PopScopes(1); return Clause(pattern, expr); @@ -1020,12 +1121,13 @@ class Parser { Expr scrutinee = ParseExpr(); Array clauses = ParseSequence( - TokenType::LCurly, TokenType::Comma, TokenType::RCurly, [&] { return ParseMatchArm(); }); + TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&] { return ParseMatchArm(); }); return relay::Match(scrutinee, clauses, is_total); } Expr ParseExprBinOp() { + DLOG(INFO) << "Parser::ParseExprBinOp"; return ConsumeWhitespace([this] { // We must parse at least one expression, the default // case is that there is no operator and we will fall @@ -1098,87 +1200,186 @@ class Parser { }); } - Attrs ParseAttrs(const std::string& type_key) { + ObjectRef ParseAttributeValue() { + DLOG(INFO) << "Parser::ParseAttributeValue"; + auto next = Peek(); + switch (next->token_type) { + case TokenType::kFloat: + case TokenType::kInteger: + case TokenType::kBoolean: + case TokenType::kStringLiteral: + return Match(next->token_type)->data; + case TokenType::kLSquare: { + return ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, + [&]() { return ParseAttributeValue(); }); + } + case TokenType::kOpenParen: { + // TODO(@jroesch: need to figure out bracket vs. sequence) + // return ParseSequence(TokenType::kOpenParen, TokenType::kComma, + // TokenType::kCloseParen, + // [&]() { return ParseAttributeValue(); }); + return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, + [&]() { return ParseAttributeValue(); }); + } + // TODO(@jroesch): not sure about this being the right way to handle nulls. + case TokenType::kIdentifier: { + if (auto text = next->data.as()) { + std::string id = GetRef(text); + if (id == "nullptr") { + Match(TokenType::kIdentifier); + return ObjectRef(); + } + } + } + default: + return ParseAtomicExpr(); + } + } + + Map ParseAttrs() { + DLOG(INFO) << "Parser::ParseAttrs"; Map kwargs; - auto attrs = tvm::ReflectionVTable::Global()->CreateObject(type_key, kwargs); - LOG(FATAL) << Attrs(); - return Attrs(); + while (Peek()->token_type == TokenType::kIdentifier) { + auto key = Match(TokenType::kIdentifier).ToString(); + Match(TokenType::kEqual); + // TOOD(@jroesch): syntactically what do we allow to appear in attribute right hand side. + auto value = ParseAttributeValue(); + // TODO(@jroesch): we need a robust way to handle this writing dtypes as strings in text + // format is bad. + kwargs.Set(key, value); + WhenMatch(TokenType::kComma); + } + DLOG(INFO) << "Parser::ParseAttrs: kwargs=" << kwargs; + return kwargs; } Expr ParseCallArgs(Expr op) { - Attrs call_attrs; - if (Peek()->token_type == TokenType::OpenParen) { - Array args = ParseSequence( - TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, - [&] { return ParseExpr(); }, - [&] { - auto is_ident = Lookahead(1)->token_type == TokenType::Identifier; - auto next_is_equal = Lookahead(2)->token_type == TokenType::Equal; - - if (is_ident && next_is_equal) { - if (auto op_node = op.as()) { - call_attrs = ParseAttrs(op_node->attrs_type_key); + try { + DLOG(INFO) << "Parser::ParseCallArgs"; + Map raw_attrs; + std::string op_key; + bool is_op = false; + + if (auto op_node = op.as()) { + is_op = true; + op_key = op_node->attrs_type_key; + } + + if (Peek()->token_type == TokenType::kOpenParen) { + Array args = ParseSequence( + TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, + [&] { return ParseExpr(); }, + [&] { + auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; + auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; + + if (is_op && is_ident && next_is_equal) { + raw_attrs = ParseAttrs(); + return true; } - } - }); - return Expr(Call(op, args, call_attrs, {})); - } else { - return Expr(); + + return false; + }); + + Attrs attrs; + + if (is_op && op_key.size()) { + auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); + CHECK(attr_obj.defined()); + attrs = Downcast(attr_obj); + } + + return Expr(Call(op, args, attrs, {})); + } else { + return Expr(); + } + } catch (...) { + // TODO(@jroesch): AttrErrors should have fields + this->diag_ctx->Emit(Diagnostic::Error(Peek()->span)); + // << err.what()); } + + return Expr(); } Expr ParseCallExpr() { + DLOG(INFO) << "Parser::ParseCallExpr"; return ConsumeWhitespace([this] { Expr expr = ParseAtomicExpr(); // Parse as many call args as possible, building up expression // // NB(@jroesch): this seems like a hack but in order to parse curried functions // and avoid complex grammar we will parse multiple call lists in a row. - while (true) { - auto new_expr = ParseCallArgs(expr); - if (new_expr.defined()) { - expr = new_expr; - } else { - break; + while (Peek()->token_type == TokenType::kOpenParen) { + try { + auto new_expr = ParseCallArgs(expr); + + if (new_expr.defined()) { + expr = new_expr; + } else { + break; + } + } catch (...) { + // TODO(@jroesch): AttrErrors should have fields + this->diag_ctx->EmitFatal(Diagnostic::Error(Peek()->span)); + // << err.what()); } } // We need a zero-arity case for constructors. - if (expr.as()) { - return Expr(Call(expr, {})); - } else { - return expr; + if (auto ctor_node = expr.as()) { + if (ctor_node->inputs.size() == 0) { + return Expr(Call(expr, {})); + } } + + return expr; }); } + Expr GetOp(const std::string& op_name, const Token& tok) { + DLOG(INFO) << "op_name=" << op_name << " token=" << tok; + try { + return Op::Get(op_name); + } catch (dmlc::Error e) { + this->diag_ctx->Emit(Diagnostic::Error(tok->span) + << "operator `" << op_name + << "` not found, perhaps you forgot to register it?"); + return Expr(); + } + } + Expr ParseAtomicExpr() { - return ConsumeWhitespace([this] { + DLOG(INFO) << "Parser::ParseAtomicExpr"; + auto expr = ConsumeWhitespace([this] { auto next = Peek(); switch (next->token_type) { - case TokenType::Integer: - case TokenType::Float: { + case TokenType::kInteger: + case TokenType::kFloat: { Consume(next->token_type); auto number = NumberToNDArray(next); - Expr e = Constant(number); + Expr e = Constant(number, next->span); return e; } - case TokenType::Boolean: { - Consume(TokenType::Boolean); + case TokenType::kBoolean: { + Consume(TokenType::kBoolean); int value = Downcast(next->data); auto boolean = BooleanToNDarray(value); - Expr e = Constant(boolean); + Expr e = Constant(boolean, next->span); return e; } - case TokenType::Local: { - Consume(TokenType::Local); + // Parse a local of the form `%x`. + case TokenType::kLocal: { + Consume(TokenType::kLocal); return Expr(LookupLocal(next)); } - case TokenType::Global: { + // Parse a local of the form `@x`. + case TokenType::kGlobal: { auto string = next.ToString(); - Consume(TokenType::Global); + Consume(TokenType::kGlobal); auto global = global_names.Get(string); if (!global) { + // TODO(@jroesch): fix global's needing span information auto global_var = GlobalVar(string); global_names.Add(string, global_var); return Expr(global_var); @@ -1186,43 +1387,59 @@ class Parser { return Expr(global.value()); } } - case TokenType::Identifier: { - auto string = next.ToString(); - Consume(TokenType::Identifier); - auto ctor = ctors.Get(string); + // Parse a local of the form `x`. + // Right now we fail to parse `x.y`. + case TokenType::kIdentifier: { + auto ctor = ctors.Get(next.ToString()); if (ctor) { + Consume(TokenType::kIdentifier); return Expr(ctor.value()); } else { - return Expr(Op::Get(string)); + auto idents = ParseHierarchicalName(); + CHECK_NE(idents.size(), 0); + std::stringstream op_name; + int i = 0; + int periods = idents.size() - 1; + for (auto ident : idents) { + op_name << ident; + if (i < periods) { + op_name << "."; + i++; + } + } + return GetOp(op_name.str(), next); } } - case TokenType::Graph: { - Consume(TokenType::Graph); + case TokenType::kGraph: { + Consume(TokenType::kGraph); return LookupGraphBinding(next); } - case TokenType::Fn: { - Consume(TokenType::Fn); + case TokenType::kMetaReference: { + return Downcast(ParseMetaRef()); + } + case TokenType::kFn: { + Consume(TokenType::kFn); return Expr(ParseFunctionDef()); } - case TokenType::OpenParen: { - Consume(TokenType::OpenParen); + case TokenType::kOpenParen: { + Consume(TokenType::kOpenParen); // parse '(' ')' - if (WhenMatch(TokenType::CloseParen)) { + if (WhenMatch(TokenType::kCloseParen)) { return Expr(Tuple(Array())); } else { auto expr = ParseExpr(); // parse '(' expr ')' - if (WhenMatch(TokenType::CloseParen)) { + if (WhenMatch(TokenType::kCloseParen)) { return expr; // parse '( expr ',' * ')' - } else if (WhenMatch(TokenType::Comma)) { + } else if (WhenMatch(TokenType::kComma)) { Array exprs = {expr}; while (true) { - if (WhenMatch(TokenType::CloseParen)) { + if (WhenMatch(TokenType::kCloseParen)) { break; } else { auto expr = ParseExpr(); - WhenMatch(TokenType::Comma); + WhenMatch(TokenType::kComma); exprs.push_back(expr); } } @@ -1231,33 +1448,76 @@ class Parser { } } default: { - std::stringstream msg; - msg << "expected an expression found " << Pretty(next->token_type); - diag_ctx.Emit({next->line, next->column, msg.str()}); - diag_ctx.Render(std::cout); + this->diag_ctx->EmitFatal(Diagnostic::Error(next->span) + << "expected an expression found " + << Pretty(next->token_type)); return Expr(); } } }); + + if (WhenMatch(TokenType::kPeriod)) { + auto index = Match(TokenType::kInteger).ToNumber(); + expr = relay::TupleGetItem(expr, index); + } + + return expr; + } + + /*! \brief Parse a hierarchical name. + * + * The tokenizer produces a token stream of . + * and so on for names of the form `nn.conv2d`. + * Currently we only use string names everywhere instead + * of a notion of a hierarchical name. + * + * The below utility reassembles a token stream into a + * single stream inserting the required periods needed + * to look up registered names. + */ + Array ParseHierarchicalName() { + Array idents; + while (Peek()->token_type == TokenType::kIdentifier) { + auto name = Peek().ToString(); + idents.push_back(name); + Consume(TokenType::kIdentifier); + + // Keep parsing while we see a trailing period. + if (Peek()->token_type == TokenType::kPeriod) { + Consume(TokenType::kPeriod); + continue; + } else { + // No more periods means we are done! + break; + } + } + + return idents; } /*! \brief Parse a shape. */ Array ParseShape() { - auto dims = ParseSequence(TokenType::OpenParen, TokenType::Comma, - TokenType::CloseParen, [&]() { - auto tok = Match(TokenType::Integer); - return Downcast(tok->data); - }); + auto dims = ParseSequence( + TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() { + tvm::PrimExpr dim; + if (Peek()->token_type == TokenType::kMetaReference) { + dim = Downcast(ParseMetaRef()); + } else { + dim = Downcast(Match(TokenType::kInteger)->data); + } + + return dim; + }); return dims; } /*! \brief Parse a function type. */ Type ParseFunctionType() { - auto ty_params = ParseSequence(TokenType::OpenParen, TokenType::Comma, - TokenType::CloseParen, [&]() { return ParseType(); }); + auto ty_params = ParseSequence(TokenType::kOpenParen, TokenType::kComma, + TokenType::kCloseParen, [&]() { return ParseType(); }); - Match(TokenType::Minus); - Match(TokenType::RAngle); + Match(TokenType::kMinus); + Match(TokenType::kRAngle); auto ret_type = ParseType(); return relay::FuncType(ty_params, ret_type, {}, {}); @@ -1278,8 +1538,8 @@ class Parser { CHECK(head_type.defined()) << "internal error: head type must be defined"; Array arg_types; - if (Peek()->token_type == TokenType::LSquare) { - arg_types = ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, + if (Peek()->token_type == TokenType::kLSquare) { + arg_types = ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { return ParseType(); }); } @@ -1298,21 +1558,21 @@ class Parser { Type ParseType() { auto tok = Peek(); - if (tok->token_type == TokenType::OpenParen) { - auto tys = ParseSequence(TokenType::OpenParen, TokenType::Comma, - TokenType::CloseParen, [&]() { return ParseType(); }); + if (tok->token_type == TokenType::kOpenParen) { + auto tys = ParseSequence(TokenType::kOpenParen, TokenType::kComma, + TokenType::kCloseParen, [&]() { return ParseType(); }); return relay::TupleType(tys); - } else if (WhenMatch(TokenType::Fn)) { + } else if (WhenMatch(TokenType::kFn)) { return ParseFunctionType(); - } else if (WhenMatch(TokenType::Identifier)) { + } else if (WhenMatch(TokenType::kIdentifier)) { auto id = tok.ToString(); if (id == "Tensor") { - Match(TokenType::LSquare); + Match(TokenType::kLSquare); auto shape = ParseShape(); - Match(TokenType::Comma); - auto dtype_tok = Match(TokenType::Identifier); + Match(TokenType::kComma); + auto dtype_tok = Match(TokenType::kIdentifier); auto dtype = DataType(String2DLDataType(dtype_tok.ToString())); - Match(TokenType::RSquare); + Match(TokenType::kRSquare); return TensorType(shape, dtype); } else { auto ty = tok.ToString(); @@ -1326,14 +1586,11 @@ class Parser { } } } - if (WhenMatch(TokenType::Underscore)) { + if (WhenMatch(TokenType::kUnderscore)) { return IncompleteType(); } else { - std::stringstream msg; - msg << "failed to parse type found "; - msg << tok; - diag_ctx.Emit({tok->line, tok->column, msg.str()}); - diag_ctx.Render(std::cout); + this->diag_ctx->EmitFatal(Diagnostic::Error(tok->span) + << "failed to parse type found " << tok); return Type(); } } @@ -1342,7 +1599,7 @@ class Parser { R ConsumeWhitespace(std::function func) { auto old = this->ignore_whitespace; this->ignore_whitespace = true; - while (tokens[pos]->token_type == TokenType::Whitespace) { + while (tokens[pos]->token_type == TokenType::kWhitespace) { pos++; } auto res = func(); @@ -1350,8 +1607,13 @@ class Parser { return res; } - // TODO(@jroesch): this is the final remaining feature. - ObjectRef ParseMetadata() { return ObjectRef(); } + Map> ParseMetadata() { + if (Peek()->token_type == TokenType::kMetadata) { + return Match(TokenType::kMetadata).ToMetadata(); + } else { + return Map>(); + } + } /*! \brief A helper for debugging the parser, displays the next N tokens in the token stream. */ void DisplayNextN(int n) { @@ -1380,27 +1642,49 @@ class Parser { }; IRModule ParseModule(std::string file_name, std::string file_content) { - auto tokens = Tokenize(file_content); - Parser parser(tokens, DefaultOpTable(), Source(file_content)); - return parser.ParseModule(); + DLOG(INFO) << "ParseModule"; + SourceName src_name = SourceName::Get(file_name); + Source src(src_name, file_content); + DiagnosticContext ctx(src); + auto tokens_and_table = Tokenize(&ctx, src_name, file_content); + auto tokens = tokens_and_table.first; + auto meta_data_table = tokens_and_table.second; + Parser parser(&ctx, src_name, tokens, DefaultOpTable(), src, meta_data_table.ToMetadata()); + auto mod = parser.ParseModule(); + // NB(@jroesch): it is very important that we render any errors before we procede + // if there were any errors which allow the parser to procede we must render them + // here. + parser.diag_ctx->Render(std::cout); + return mod; } Expr ParseExpr(std::string file_name, std::string file_content) { - auto tokens = Tokenize(file_content); - Parser parser(tokens, DefaultOpTable(), Source(file_content)); + DLOG(INFO) << "ParseExpr"; + SourceName src_name = SourceName::Get(file_name); + Source src(src_name, file_content); + DiagnosticContext ctx(src); + auto tokens_and_table = Tokenize(&ctx, src_name, file_content); + auto tokens = tokens_and_table.first; + auto meta_data_table = tokens_and_table.second; + Parser parser(&ctx, src_name, tokens, DefaultOpTable(), src, meta_data_table.ToMetadata()); + parser.ParseSemVer(false); parser.PushScope(); auto expr = parser.ParseExpr(); - parser.Match(TokenType::EndOfFile); + parser.Match(TokenType::kEndOfFile); + // NB(@jroesch): it is very important that we render any errors before we procede + // if there were any errors which allow the parser to procede we must render them + // here. + parser.diag_ctx->Render(std::cout); return expr; } TVM_REGISTER_GLOBAL("parser.ParseModule") - .set_body_typed([](std::string file_name, std::string file_content) { + .set_body_typed([](tvm::String file_name, tvm::String file_content) { return ParseModule(file_name, file_content); }); TVM_REGISTER_GLOBAL("parser.ParseExpr") - .set_body_typed([](std::string file_name, std::string file_content) { + .set_body_typed([](tvm::String file_name, tvm::String file_content) { return ParseExpr(file_name, file_content); }); diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc new file mode 100644 index 0000000000000..a2efdb5a88fd5 --- /dev/null +++ b/src/parser/source_map.cc @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file source_map.cc + * \brief The implementation of the source map data structure. + */ +#include +#include + +namespace tvm { +namespace parser { + +/*! \brief Construct a source from a string. */ +Source::Source(const SourceName& src_name, const std::string& source) + : source_name(src_name), source(source) { + int index = 0; + int length = 0; + line_map.push_back({index, length}); + for (auto c : source) { + if (c == '\n') { + // Record the length of the line. + line_map.back().second = length; + // Bump past the newline. + index += 1; + // Record the start of the next line, and put placeholder for length. + line_map.push_back({index, 0}); + // Reset length to zero. + length = 0; + } else { + length += 1; + index += 1; + } + } + line_map.back().second = length; +} + +/*! \brief Generate an error message at a specific line and column with the + * annotated message. + * + * The error is written directly to the `out` std::ostream. + * + * \param out The output ostream. + * \param line The line at which to report a diagnostic. + * \param line The column at which to report a diagnostic. + * \param msg The message to attach. + */ +void Source::ReportAt(std::ostream& out, const Span& span, const std::string& msg) const { + DLOG(INFO) << "Source::ReportAt" + << "span = " << span << "msg = " << msg; + int line = span->line; + int column = span->column; + + CHECK(line - 1 <= static_cast(line_map.size())) + << "requested line: " << (line - 1) << "line_map size: " << line_map.size() + << "source: " << source; + + // Adjust for zero indexing, now have (line_start, line_length); + auto range = line_map.at(line - 1); + int line_start = range.first; + int line_length = range.second; + out << "file:" << line << ":" << column << ": parse error: " << msg << std::endl; + out << " " << source.substr(line_start, line_length) << std::endl; + out << " "; + std::stringstream marker; + for (int i = 1; i <= line_length; i++) { + if (i == column) { + marker << "^"; + } else if ((column - i) < 3) { + marker << "~"; + } else if ((i - column) < 3) { + marker << "~"; + } else { + marker << " "; + } + } + out << marker.str(); + out << std::endl; +} + +// TVM_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get); + +// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +// .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { +// auto* node = static_cast(ref.get()); +// p->stream << "SourceName(" << node->name << ", " << node << ")"; +// }); + +TVM_REGISTER_NODE_TYPE(SourceMapNode); + +SourceMap::SourceMap(Map source_map) { + auto n = make_object(); + n->source_map = std::move(source_map); + data_ = std::move(n); +} + +} // namespace parser +} // namespace tvm diff --git a/src/parser/token.h b/src/parser/token.h index d7aac23ca3504..3750ec568cc84 100644 --- a/src/parser/token.h +++ b/src/parser/token.h @@ -25,6 +25,7 @@ #ifndef TVM_PARSER_TOKEN_H_ #define TVM_PARSER_TOKEN_H_ +#include #include #include @@ -37,160 +38,172 @@ namespace parser { using namespace runtime; -enum TokenType { - CommentStart, - CommentEnd, - LineComment, - Comment, - Whitespace, - Newline, - StringLiteral, - Identifier, - Local, - Global, - Op, - Graph, - OpenParen, - CloseParen, - AtSymbol, - Percent, - Comma, - Period, - Equal, - Semicolon, - Colon, - Integer, - Float, - Division, - Boolean, - Plus, - Star, - Minus, - RAngle, - LAngle, - RCurly, - LCurly, - RSquare, - LSquare, - Bang, - At, - Question, - If, - Else, - Underscore, - Let, - Fn, - Defn, - TypeDef, - Extern, - Match, - PartialMatch, - Unknown, - EndOfFile, - Null, +enum class TokenType { + kCommentStart, + kCommentEnd, + kLineComment, + kComment, + kWhitespace, + kNewline, + kStringLiteral, + kIdentifier, + kLocal, + kGlobal, + kOp, + kGraph, + kOpenParen, + kCloseParen, + kAtSymbol, + kPercent, + kComma, + kPeriod, + kEqual, + kSemicolon, + kColon, + kInteger, + kFloat, + kDivision, + kBoolean, + kPlus, + kStar, + kMinus, + kRAngle, + kLAngle, + kRCurly, + kLCurly, + kRSquare, + kLSquare, + kBang, + kAt, + kQuestion, + kIf, + kElse, + kUnderscore, + kLet, + kFn, + kDefn, + kTypeDef, + kExtern, + kMatch, + kPartialMatch, + kMetadata, + kMetaReference, + kFreeVar, + kVersion, + kUnknown, + kEndOfFile, + kNull, }; std::string ToString(const TokenType& token_type) { switch (token_type) { - case TokenType::CommentStart: + case TokenType::kCommentStart: return "CommentStart"; - case TokenType::CommentEnd: + case TokenType::kCommentEnd: return "CommentEnd"; - case TokenType::LineComment: + case TokenType::kLineComment: return "LineComment"; - case TokenType::Comment: + case TokenType::kComment: return "Comment"; - case TokenType::Whitespace: + case TokenType::kWhitespace: return "WhiteSpace"; - case TokenType::Newline: + case TokenType::kNewline: return "Newline"; - case TokenType::StringLiteral: + case TokenType::kStringLiteral: return "StringLiteral"; - case TokenType::Identifier: + case TokenType::kIdentifier: return "Identifier"; - case TokenType::Local: + case TokenType::kLocal: return "Local"; - case TokenType::Global: + case TokenType::kGlobal: return "Global"; - case TokenType::Graph: + case TokenType::kGraph: return "Graph"; - case TokenType::Op: + case TokenType::kOp: return "Op"; - case TokenType::OpenParen: + case TokenType::kOpenParen: return "OpenParen"; - case TokenType::CloseParen: + case TokenType::kCloseParen: return "CloseParen"; - case TokenType::AtSymbol: + case TokenType::kAtSymbol: return "AtSymbol"; - case TokenType::Percent: + case TokenType::kPercent: return "Percent"; - case TokenType::Comma: + case TokenType::kComma: return "Comma"; - case TokenType::Colon: + case TokenType::kColon: return "Colon"; - case TokenType::Semicolon: + case TokenType::kSemicolon: return "Semicolon"; - case TokenType::Period: + case TokenType::kPeriod: return "Period"; - case TokenType::Equal: + case TokenType::kEqual: return "Equal"; - case TokenType::Integer: + case TokenType::kInteger: return "Integer"; - case TokenType::Float: + case TokenType::kFloat: return "Float"; - case TokenType::Plus: + case TokenType::kPlus: return "Plus"; - case TokenType::Star: + case TokenType::kStar: return "Star"; - case TokenType::Minus: + case TokenType::kMinus: return "Minus"; - case TokenType::Division: + case TokenType::kDivision: return "Division"; - case TokenType::RAngle: + case TokenType::kRAngle: return "RAngle"; - case TokenType::LAngle: + case TokenType::kLAngle: return "LAngle"; - case TokenType::RCurly: + case TokenType::kRCurly: return "RCurly"; - case TokenType::LCurly: + case TokenType::kLCurly: return "LCurly"; - case TokenType::RSquare: + case TokenType::kRSquare: return "RSquare"; - case TokenType::LSquare: + case TokenType::kLSquare: return "LSquare"; - case TokenType::Bang: + case TokenType::kBang: return "Bang"; - case TokenType::Underscore: + case TokenType::kUnderscore: return "Underscore"; - case TokenType::At: + case TokenType::kAt: return "At"; - case TokenType::Let: + case TokenType::kLet: return "Let"; - case TokenType::If: + case TokenType::kIf: return "If"; - case TokenType::Else: + case TokenType::kElse: return "Else"; - case TokenType::Fn: + case TokenType::kFn: return "Fn"; - case TokenType::Defn: + case TokenType::kDefn: return "Defn"; - case TokenType::TypeDef: + case TokenType::kTypeDef: return "TypeDef"; - case TokenType::Extern: + case TokenType::kExtern: return "Extern"; - case TokenType::Match: + case TokenType::kMatch: return "Match"; - case TokenType::PartialMatch: + case TokenType::kPartialMatch: return "PartialMatch"; - case TokenType::Question: + case TokenType::kQuestion: return "Question"; - case TokenType::Boolean: + case TokenType::kBoolean: return "Boolean"; - case TokenType::Unknown: + case TokenType::kMetadata: + return "Metadata"; + case TokenType::kMetaReference: + return "MetaReference"; + case TokenType::kFreeVar: + return "FreeVar"; + case TokenType::kVersion: + return "Version"; + case TokenType::kUnknown: return "Unknown"; - case TokenType::EndOfFile: + case TokenType::kEndOfFile: return "EndOfFile"; - case TokenType::Null: + case TokenType::kNull: return "Null"; // Older compilers warn even though the above code is exhaustive. default: @@ -201,106 +214,114 @@ std::string ToString(const TokenType& token_type) { std::string Pretty(const TokenType& token_type) { switch (token_type) { - case TokenType::CommentStart: + case TokenType::kCommentStart: return "`/*`"; - case TokenType::CommentEnd: + case TokenType::kCommentEnd: return "`*/`"; - case TokenType::LineComment: + case TokenType::kLineComment: return "`//`"; - case TokenType::Comment: + case TokenType::kComment: return "comment"; - case TokenType::Whitespace: + case TokenType::kWhitespace: return "whitespace"; - case TokenType::Newline: + case TokenType::kNewline: return "newline"; - case TokenType::StringLiteral: + case TokenType::kStringLiteral: return "string literal"; - case TokenType::Identifier: + case TokenType::kIdentifier: return "identifier"; - case TokenType::Local: + case TokenType::kLocal: return "local variable"; - case TokenType::Global: + case TokenType::kGlobal: return "global variable"; - case TokenType::Graph: + case TokenType::kGraph: return "graph variable"; - case TokenType::Op: + case TokenType::kOp: return "operator"; - case TokenType::OpenParen: + case TokenType::kOpenParen: return "`(`"; - case TokenType::CloseParen: + case TokenType::kCloseParen: return "`)`"; - case TokenType::AtSymbol: + case TokenType::kAtSymbol: return "`@`"; - case TokenType::Percent: + case TokenType::kPercent: return "`%`"; - case TokenType::Comma: + case TokenType::kComma: return "`,`"; - case TokenType::Colon: + case TokenType::kColon: return "`:`"; - case TokenType::Semicolon: + case TokenType::kSemicolon: return "`;`"; - case TokenType::Period: + case TokenType::kPeriod: return "`.`"; - case TokenType::Equal: + case TokenType::kEqual: return "`=`"; - case TokenType::Integer: + case TokenType::kInteger: return "integer"; - case TokenType::Float: + case TokenType::kFloat: return "float"; - case TokenType::Plus: + case TokenType::kPlus: return "`+`"; - case TokenType::Star: + case TokenType::kStar: return "`*`"; - case TokenType::Minus: + case TokenType::kMinus: return "`-`"; - case TokenType::Division: + case TokenType::kDivision: return "`/`"; - case TokenType::RAngle: + case TokenType::kRAngle: return "`<`"; - case TokenType::LAngle: + case TokenType::kLAngle: return "`>`"; - case TokenType::RCurly: + case TokenType::kRCurly: return "`}`"; - case TokenType::LCurly: + case TokenType::kLCurly: return "`{`"; - case TokenType::RSquare: + case TokenType::kRSquare: return "`]`"; - case TokenType::LSquare: + case TokenType::kLSquare: return "`[`"; - case TokenType::Bang: + case TokenType::kBang: return "`!`"; - case TokenType::Underscore: + case TokenType::kUnderscore: return "`_`"; - case TokenType::At: + case TokenType::kAt: return "`@`"; - case TokenType::Let: + case TokenType::kLet: return "`let`"; - case TokenType::If: + case TokenType::kIf: return "`if`"; - case TokenType::Else: + case TokenType::kElse: return "`else`"; - case TokenType::Fn: + case TokenType::kFn: return "`fn`"; - case TokenType::Defn: + case TokenType::kDefn: return "`def`"; - case TokenType::TypeDef: + case TokenType::kTypeDef: return "`type`"; - case TokenType::Extern: + case TokenType::kExtern: return "`extern`"; - case TokenType::Boolean: + case TokenType::kBoolean: return "boolean"; - case TokenType::Match: + case TokenType::kMetadata: + return "metadata section"; + case TokenType::kMetaReference: + return "`meta`"; + case TokenType::kFreeVar: + return "`free_var`"; + case TokenType::kMatch: return "`match`"; - case TokenType::PartialMatch: + case TokenType::kPartialMatch: return "`match?`"; - case TokenType::Question: + case TokenType::kQuestion: return "`?`"; - case TokenType::Unknown: + case TokenType::kUnknown: return "unknown"; - case TokenType::EndOfFile: + case TokenType::kEndOfFile: return "end of file"; - case TokenType::Null: + case TokenType::kNull: return "null"; + case TokenType::kVersion: + return "version attribute"; // Older compilers warn even though the above code is exhaustive. default: LOG(FATAL) << "unreachable code"; @@ -312,8 +333,7 @@ class Token; class TokenNode : public Object { public: - int line; - int column; + Span span; TokenType token_type; mutable runtime::ObjectRef data; @@ -326,37 +346,46 @@ class TokenNode : public Object { TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "Token(line=" << node->line << ", column=" << node->column - << ", token_type=" << ToString(node->token_type) << ", data=" << node->data << ")"; + p->stream << "Token(span=" << node->span << ", token_type=" << ToString(node->token_type) + << ", data=" << node->data << ")"; }); TVM_REGISTER_NODE_TYPE(TokenNode); class Token : public ObjectRef { public: - TVM_DLL explicit Token(int line, int column, TokenType token_type, ObjectRef data = ObjectRef()); + TVM_DLL explicit Token(Span span, TokenType token_type, ObjectRef data = ObjectRef()); static Token Null(); int64_t ToNumber() const; std::string ToString() const; + Map> ToMetadata() const; TVM_DEFINE_OBJECT_REF_METHODS(Token, ObjectRef, TokenNode); }; -Token::Token(int line, int column, TokenType token_type, ObjectRef data) { +Token::Token(Span span, TokenType token_type, ObjectRef data) { ObjectPtr n = make_object(); - n->line = line; - n->column = column; + n->span = span; n->token_type = token_type; n->data = data; data_ = std::move(n); } -Token Token::Null() { return Token(0, 0, TokenType::Null); } +Token Token::Null() { return Token(Span(SourceName(), 0, 0, 0, 0), TokenType::kNull); } int64_t Token::ToNumber() const { return Downcast(this->operator->()->data); } std::string Token::ToString() const { return Downcast(this->operator->()->data); } +Map> Token::ToMetadata() const { + ObjectRef data = this->operator->()->data; + if (data.defined()) { + return Downcast>>(data); + } else { + return Map>({}); + } +} + } // namespace parser } // namespace tvm #endif // TVM_PARSER_TOKEN_H_ diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index f6c27340e09a8..88a49290dc3d7 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -24,14 +24,17 @@ #ifndef TVM_PARSER_TOKENIZER_H_ #define TVM_PARSER_TOKENIZER_H_ +#include #include #include #include #include #include +#include #include +#include "./meta_ref.h" #include "./token.h" namespace tvm { @@ -39,6 +42,17 @@ namespace parser { using namespace runtime; +// trim from start (in place) +static inline void ltrim(std::string& s) { // NOLINT(*) + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !std::isspace(ch); })); +} + +// trim from end (in place) +static inline void rtrim(std::string& s) { // NOLINT(*) + s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !std::isspace(ch); }).base(), + s.end()); +} + bool IsDigit(char c) { return '0' <= c && c <= '9'; } bool IsWhitespace(char c) { return ' ' == c || c == '\t' || c == '\n'; } @@ -53,11 +67,16 @@ bool IsIdentLetter(char c) { return '_' == c || ('a' <= c && c <= 'z') || ('A' < bool IsIdent(char c) { return IsIdentLetter(c) || IsDigit(c); } static std::unordered_map KEYWORD_TABLE = { - {"let", TokenType::Let}, {"fn", TokenType::Fn}, {"def", TokenType::Defn}, - {"if", TokenType::If}, {"else", TokenType::Else}, {"type", TokenType::TypeDef}, - {"match", TokenType::Match}, {"extern", TokenType::Extern}}; + {"let", TokenType::kLet}, {"fn", TokenType::kFn}, + {"def", TokenType::kDefn}, {"if", TokenType::kIf}, + {"else", TokenType::kElse}, {"type", TokenType::kTypeDef}, + {"match", TokenType::kMatch}, {"extern", TokenType::kExtern}, + {"free_var", TokenType::kFreeVar}}; struct Tokenizer { + DiagnosticContext* diag_ctx; + const SourceName& source_name; + size_t pos; int col; int line; @@ -84,8 +103,16 @@ struct Tokenizer { return this->source.at(this->pos); } - Token NewToken(TokenType token_type, ObjectRef data = ObjectRef()) { - return Token(this->line, this->col, token_type, data); + Token NewToken(TokenType token_type, ObjectRef data = ObjectRef(), int lines = 0, int cols = 1) { + auto span = + Span(this->source_name, this->line, this->line + lines, this->col, this->col + cols); + return Token(span, token_type, data); + } + + Span SpanFrom(int line, int column) { + int end_line = this->line; + int end_column = this->col; + return Span(this->source_name, line, end_line, column, end_column); } enum CommentParserState { @@ -104,7 +131,7 @@ struct Tokenizer { CommentParserState state = CommentParserState::Proceed; int nesting = 1; - while (true) { + while (More()) { switch (state) { case CommentParserState::Proceed: { if (Peek() == '/') { @@ -130,11 +157,11 @@ struct Tokenizer { Next(); buffer->pop_back(); return; - } else { - buffer->operator+=(Next()); - state = CommentParserState::Proceed; } } + + buffer->operator+=(Next()); + state = CommentParserState::Proceed; continue; } } @@ -148,7 +175,7 @@ struct Tokenizer { if (is_float) { throw std::invalid_argument("is_float"); } - auto token = NewToken(TokenType::Integer); + auto token = NewToken(TokenType::kInteger); size_t index = 0; int value = std::stoi(number, &index); if (number.size() > index) { @@ -158,7 +185,7 @@ struct Tokenizer { token->data = tvm::Integer(value); return token; } catch (const std::invalid_argument& ia) { - auto token = NewToken(TokenType::Float); + auto token = NewToken(TokenType::kFloat); if (number.back() == 'f') { number.pop_back(); @@ -171,27 +198,132 @@ struct Tokenizer { } } + bool MatchString(const std::string& string) { + int start = this->pos; + + for (auto c : string) { + if (Peek() != c) { + this->pos = start; + return false; + } else { + Next(); + } + } + + return true; + } + + Token TokenizeMetaRef() { + int line = this->line; + int column = this->col; + + CHECK_EQ(Peek(), '['); + Next(); + std::stringstream type_key; + while (More() && Peek() != ']') { + type_key << Next(); + } + CHECK_EQ(Peek(), ']'); + Next(); + + CHECK_EQ(Peek(), '['); + Next(); + std::stringstream str_index; + while (More() && Peek() != ']') { + str_index << Next(); + } + CHECK_EQ(Peek(), ']'); + Next(); + // todo: add error handling around bad indices + auto index = ParseNumber(true, false, str_index.str()).ToNumber(); + auto span = SpanFrom(line, column); + return Token(span, TokenType::kMetaReference, MetaRef(type_key.str(), index)); + } + + Token TokenizeAttr() { + int line = this->line; + int column = this->col; + Next(); + if (Peek() == '[') { + Next(); + std::stringstream raw_attribute; + + while (More() && Peek() != ']') { + raw_attribute << Next(); + } + + CHECK_EQ(Next(), ']'); + + auto attribute = raw_attribute.str(); + // Clean up the white-space on both sides. + ltrim(attribute); + rtrim(attribute); + + // Metadata can only appear at the bottom of a file and goes to EOF. + if (attribute == "metadata") { + std::stringstream metadata; + while (More()) { + metadata << Next(); + } + ObjectRef metadata_map = tvm::LoadJSON(metadata.str()); + auto span = SpanFrom(line, column); + return Token(span, TokenType::kMetadata, metadata_map); + } + if (attribute.rfind("version", 0) == 0) { + std::string version = attribute.substr(attribute.find("=") + 1); + ltrim(version); + rtrim(version); + auto span = SpanFrom(line, column); + return Token(span, TokenType::kVersion, tvm::String(version)); + } else { + // TOOD(@jroesch): maybe make this a warning an continue parsing? + auto span = SpanFrom(line, column); + this->diag_ctx->EmitFatal(Diagnostic::Error(span) << "unsupported attribute " << attribute); + return Token(); + } + } else { + auto span = SpanFrom(line, column); + this->diag_ctx + ->EmitFatal(Diagnostic::Error(span) + << "`#` denotes the start of an attribute can only be followed by `[`" + << " found `" << Peek() << "`"); + return Token(); + } + } + inline Token TokenizeOnce() { + int line = this->line; + int col = this->col; auto next = Peek(); + DLOG(INFO) << "tvm::parser::TokenizeOnce: next=" << next; if (next == '\n') { - auto token = NewToken(TokenType::Newline); + auto token = NewToken(TokenType::kNewline); Next(); return token; } else if (next == '\r') { Next(); if (More() && Peek() == '\n') { - auto token = NewToken(TokenType::Newline); + auto token = NewToken(TokenType::kNewline); return token; } else { - // TODO(@jroesch): have lexer use diagnostic context too. - LOG(FATAL) << "lexer error"; + auto span = SpanFrom(line, col); + this->diag_ctx->EmitFatal( + Diagnostic::Error(span) + << "\\r carriage returns must be followed by a \\n in the TVM text format"); return Token(); } } else if (next == '"') { - LOG(FATAL) << "string not working yet"; - return NewToken(TokenType::Unknown); + // TODO(@jroesch): Properly tokenize escape sequences in strings. + // see https://github.com/apache/incubator-tvm/issues/6153. + Next(); + std::stringstream string_content; + while (More() && Peek() != '"') { + string_content << Next(); + } + Next(); + return NewToken(TokenType::kStringLiteral, tvm::String(string_content.str())); } else if (IsWhitespace(next)) { - auto token = NewToken(TokenType::Whitespace); + auto token = NewToken(TokenType::kWhitespace); Next(); return token; } else if (IsDigit(next) || next == '-') { @@ -205,7 +337,7 @@ struct Tokenizer { // with multi-token return or something. if (negs && !IsDigit(Peek())) { pos = pos - (negs - 1); - return NewToken(TokenType::Minus); + return NewToken(TokenType::kMinus); } bool is_neg = negs % 2 == 1; @@ -223,89 +355,106 @@ struct Tokenizer { return ParseNumber(!is_neg, is_float, ss.str()); } else if (next == '.') { - auto token = NewToken(TokenType::Period); + auto token = NewToken(TokenType::kPeriod); Next(); return token; } else if (next == ',') { - auto token = NewToken(TokenType::Comma); + auto token = NewToken(TokenType::kComma); Next(); return token; } else if (next == '=') { - auto token = NewToken(TokenType::Equal); + auto token = NewToken(TokenType::kEqual); Next(); return token; } else if (next == ';') { - auto token = NewToken(TokenType::Semicolon); + auto token = NewToken(TokenType::kSemicolon); Next(); return token; } else if (next == ':') { - auto token = NewToken(TokenType::Colon); + auto token = NewToken(TokenType::kColon); Next(); return token; } else if (next == '(') { - auto token = NewToken(TokenType::OpenParen); + auto token = NewToken(TokenType::kOpenParen); Next(); return token; } else if (next == ')') { - auto token = NewToken(TokenType::CloseParen); + auto token = NewToken(TokenType::kCloseParen); Next(); return token; } else if (next == '+') { - auto token = NewToken(TokenType::Plus); + auto token = NewToken(TokenType::kPlus); Next(); return token; } else if (next == '-') { - auto token = NewToken(TokenType::Minus); + auto token = NewToken(TokenType::kMinus); Next(); return token; } else if (next == '*') { - auto token = NewToken(TokenType::Star); + auto token = NewToken(TokenType::kStar); Next(); return token; } else if (next == '<') { - auto token = NewToken(TokenType::LAngle); + auto token = NewToken(TokenType::kLAngle); Next(); return token; } else if (next == '>') { - auto token = NewToken(TokenType::RAngle); + auto token = NewToken(TokenType::kRAngle); Next(); return token; } else if (next == '{') { - auto token = NewToken(TokenType::LCurly); + auto token = NewToken(TokenType::kLCurly); Next(); return token; } else if (next == '}') { - auto token = NewToken(TokenType::RCurly); + auto token = NewToken(TokenType::kRCurly); Next(); return token; } else if (next == '[') { - auto token = NewToken(TokenType::LSquare); + auto token = NewToken(TokenType::kLSquare); Next(); return token; } else if (next == ']') { - auto token = NewToken(TokenType::RSquare); + auto token = NewToken(TokenType::kRSquare); Next(); return token; } else if (next == '!') { - auto token = NewToken(TokenType::Bang); + auto token = NewToken(TokenType::kBang); Next(); return token; } else if (next == '@') { - auto token = NewToken(TokenType::At); + auto token = NewToken(TokenType::kAt); Next(); return token; } else if (next == '?') { - auto token = NewToken(TokenType::Question); + auto token = NewToken(TokenType::kQuestion); Next(); return token; + } else if (MatchString("meta")) { + return TokenizeMetaRef(); + } else if (next == '#') { + return TokenizeAttr(); } else if (next == '%') { - auto token = NewToken(TokenType::Percent); + auto token = NewToken(TokenType::kPercent); Next(); + + std::stringstream number; + while (More() && IsDigit(Peek())) { + number << Next(); + } + + auto number_str = number.str(); + if (number_str.size()) { + auto num_tok = ParseNumber(true, false, number_str); + auto span = SpanFrom(token->span->line, token->span->column); + token = Token(span, TokenType::kGraph, num_tok->data); + } + return token; } else if (next == '/') { Next(); if (Peek() == '/') { - auto token = NewToken(TokenType::LineComment); + auto token = NewToken(TokenType::kLineComment); // Consume the / Next(); std::stringstream comment; @@ -319,10 +468,10 @@ struct Tokenizer { Next(); std::string comment; MatchComment(&comment); - auto token = NewToken(TokenType::Comment, tvm::String(comment)); + auto token = NewToken(TokenType::kComment, tvm::String(comment)); return token; } else { - return NewToken(TokenType::Division); + return NewToken(TokenType::kDivision); } } else if (IsIdentLetter(next)) { std::stringstream ss; @@ -343,57 +492,78 @@ struct Tokenizer { if (it != KEYWORD_TABLE.end()) { token_type = it->second; - if (token_type == TokenType::Match) { + if (token_type == TokenType::kMatch) { if (More() && Peek() == '?') { Next(); - token_type = TokenType::PartialMatch; + token_type = TokenType::kPartialMatch; } } } else { - token_type = TokenType::Identifier; + token_type = TokenType::kIdentifier; } - return Token(line, col, token_type, tvm::String(ss.str())); + auto span = SpanFrom(line, col); + return Token(span, token_type, tvm::String(ss.str())); } else { std::stringstream ss; while (More() && !IsWhitespace(Peek())) { ss << Next(); } - auto token = NewToken(TokenType::Unknown); + auto token = NewToken(TokenType::kUnknown); token->data = tvm::String(ss.str()); return token; } } void Tokenize() { + DLOG(INFO) << "tvm::parser::Tokenize"; while (this->More()) { auto token = TokenizeOnce(); CHECK(token.defined()); this->tokens.push_back(token); } - this->tokens.push_back(NewToken(TokenType::EndOfFile)); + this->tokens.push_back(NewToken(TokenType::kEndOfFile)); } - explicit Tokenizer(std::string& source) : pos(0), col(1), line(1), source(source), tokens() {} + explicit Tokenizer(DiagnosticContext* ctx, const SourceName& source_name, + const std::string& source) + : diag_ctx(ctx), + source_name(source_name), + pos(0), + col(1), + line(1), + source(source), + tokens() {} }; -std::vector Condense(const std::vector& tokens) { +std::vector Condense(const std::vector& tokens, Token* table) { std::vector out; + bool found_metadata = false; for (size_t i = 0; i < tokens.size(); i++) { auto current = tokens.at(i); switch (current->token_type) { - case TokenType::Percent: { + case TokenType::kMetadata: { + if (!found_metadata) { + found_metadata = true; + *table = current; + } else { + LOG(FATAL) << "duplicate metadata section"; + } + continue; + } + case TokenType::kPercent: { auto next = tokens.at(i + 1); - if (next->token_type == TokenType::Identifier) { + if (next->token_type == TokenType::kIdentifier) { // Match this token. i += 1; - auto tok = Token(current->line, current->column, TokenType::Local, next->data); + // TODO(@jroesch): merge spans + auto tok = Token(current->span, TokenType::kLocal, next->data); CHECK(tok.defined()); out.push_back(tok); - } else if (next->token_type == TokenType::Integer) { + } else if (next->token_type == TokenType::kInteger) { i += 1; - auto tok = Token(current->line, current->column, TokenType::Graph, next->data); + auto tok = Token(current->span, TokenType::kGraph, next->data); CHECK(tok.defined()); out.push_back(tok); } else { @@ -402,12 +572,13 @@ std::vector Condense(const std::vector& tokens) { } continue; } - case TokenType::At: { + case TokenType::kAt: { auto next = tokens.at(i + 1); - if (next->token_type == TokenType::Identifier) { + if (next->token_type == TokenType::kIdentifier) { // Match this token. i += 1; - auto tok = Token(current->line, current->column, TokenType::Global, next->data); + // TODO(@jroesch): merge spans + auto tok = Token(current->span, TokenType::kGlobal, next->data); CHECK(tok.defined()); out.push_back(tok); } else { @@ -416,17 +587,18 @@ std::vector Condense(const std::vector& tokens) { } continue; } - case TokenType::Identifier: { + case TokenType::kIdentifier: { std::string str = Downcast(current->data); Token tok; + // TODO(@jroesch): merge spans if (str == "True") { auto data = tvm::Integer(1); - tok = Token(current->line, current->column, TokenType::Boolean, data); + tok = Token(current->span, TokenType::kBoolean, data); } else if (str == "False") { auto data = tvm::Integer(0); - tok = Token(current->line, current->column, TokenType::Boolean, data); + tok = Token(current->span, TokenType::kBoolean, data); } else if (str == "_") { - tok = Token(current->line, current->column, TokenType::Underscore); + tok = Token(current->span, TokenType::kUnderscore); } else { tok = current; } @@ -443,14 +615,16 @@ std::vector Condense(const std::vector& tokens) { return out; } -std::vector Tokenize(std::string source) { - auto tokenizer = Tokenizer(source); +std::pair, Token> Tokenize(DiagnosticContext* ctx, const SourceName& source_name, + const std::string& source) { + auto tokenizer = Tokenizer(ctx, source_name, source); tokenizer.Tokenize(); - auto tokens = Condense(tokenizer.tokens); + Token meta_table(Span(), TokenType::kUnknown, ObjectRef()); + auto tokens = Condense(tokenizer.tokens, &meta_table); for (auto token : tokens) { CHECK(token.defined()); } - return tokens; + return {tokens, meta_table}; } } // namespace parser diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index ee11548edf29c..1b09052a63d84 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -39,6 +39,7 @@ #include #include "../ir/attr_functor.h" +#include "../parser/meta_ref.h" #include "../relay/analysis/dependency_graph.h" #include "doc.h" #include "meta_data.h" @@ -246,6 +247,7 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) { // determine whether to inline bool inline_expr = AlwaysInline(expr); + if (try_inline) { inline_expr |= IsUnique(expr); } @@ -254,6 +256,7 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) { if (it != memo_.end()) return it->second; Doc printed_expr; + if (meta) { printed_expr = meta_->GetMetaNode(GetRef(expr.get())); } else if (!inline_expr && expr.as()) { @@ -272,7 +275,7 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) { if (expr.as()) { // This is our first time visiting the var and we hit the VarNode case // in the visitor. Thus the variable is free. - doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine(); + doc_stack_.back() << "free_var " << printed_expr << ";" << Doc::NewLine(); // Memoization is done in AllocVar. return memo_[expr]; } else if (inline_expr) { @@ -721,6 +724,8 @@ Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { Doc printed_attr; if (value.as()) { printed_attr << "?"; + } else if (auto str_obj = value.as()) { + printed_attr << Doc::StrLiteral(GetRef(str_obj)); } else if (meta) { printed_attr = meta_->GetMetaNode(Downcast(value)); } else { diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index 2993d38234ead..1e882db1fd61c 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -31,17 +31,7 @@ namespace tvm { -static const char* kSemVer = "v0.0.4"; - -// TODO(tvm-team): split into files, related: arith/analyzer.h -// -// - text_printer.h (common header) -// - text_printer.cc (prints modules dispatch into relay and tir files) -// - type_text_printer.cc(specific printing logics for types, -// can also consider put under type_text_printer) -// - Implements AsText -// - relay_text_printer.cc (specific printing logics for relay) -// - tir_text_printer.cc (specific printing logics for TIR) +static const char* kSemVer = "0.0.5"; Doc TextPrinter::PrintMod(const IRModule& mod) { Doc doc; @@ -77,14 +67,14 @@ Doc TextPrinter::PrintMod(const IRModule& mod) { String PrettyPrint(const ObjectRef& node) { Doc doc; - doc << TextPrinter(false, nullptr).PrintFinal(node); + doc << TextPrinter(false, nullptr, false).PrintFinal(node); return doc.str(); } String AsText(const ObjectRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate) { Doc doc; - doc << kSemVer << Doc::NewLine(); + doc << "#[version = \"" << kSemVer << "\"]" << Doc::NewLine(); runtime::TypedPackedFunc ftyped = nullptr; if (annotate != nullptr) { ftyped = runtime::TypedPackedFunc( diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 7baa3878ff720..b65b03c380635 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -355,14 +355,20 @@ namespace tvm { class TextPrinter { public: explicit TextPrinter(bool show_meta_data, - const runtime::TypedPackedFunc& annotate) + const runtime::TypedPackedFunc& annotate, + bool show_warning = true) : show_meta_data_(show_meta_data), + show_warning_(show_warning), annotate_(annotate), relay_text_printer_(show_meta_data, &meta_, annotate), tir_text_printer_(show_meta_data, &meta_) {} /*! \brief whether show meta data */ bool show_meta_data_; + + /*! \brief whether show the meta data warning message */ + bool show_warning_; + /*! \brief meta data context */ TextMetaDataContext meta_; /*! \brief additional comment function */ @@ -385,10 +391,12 @@ class TextPrinter { if (!meta_.empty()) { doc << Doc::NewLine(); if (show_meta_data_) { - // append meta data in the end. - doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection(); - } else { - doc << "// meta data omitted. you can use show_meta_data=True to include meta data"; + doc << "#[metadata]" << Doc::NewLine() << meta_.GetMetaSection(); + } else if (show_warning_) { + doc << "/* For debugging purposes the metadata section has been omitted." << Doc::NewLine() + << " * If you would like to see the full metadata section you can set the " + << Doc::NewLine() << " * option to `True` when invoking `astext`. " << Doc::NewLine() + << " */"; } } return doc; diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index d808351e841c8..ba9743cc35bf7 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -116,11 +116,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "ClauseNode(" << node->lhs << ", " << node->rhs << ")"; }); -Match::Match(Expr data, tvm::Array clauses, bool complete) { +Match::Match(Expr data, tvm::Array clauses, bool complete, Span span) { ObjectPtr n = make_object(); n->data = std::move(data); n->clauses = std::move(clauses); n->complete = complete; + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 1d9e3cef12b70..237cb35d84556 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -30,9 +30,10 @@ namespace relay { using tvm::ReprPrinter; using namespace tvm::runtime; -Constant::Constant(runtime::NDArray data) { +Constant::Constant(runtime::NDArray data, Span span) { ObjectPtr n = make_object(); n->data = std::move(data); + n->span = std::move(span); data_ = std::move(n); } @@ -63,9 +64,10 @@ TensorType ConstantNode::tensor_type() const { return TensorType(shape, dtype); } -Tuple::Tuple(tvm::Array fields) { +Tuple::Tuple(tvm::Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); + n->span = std::move(span); data_ = std::move(n); } @@ -81,10 +83,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "Tuple(" << node->fields << ")"; }); -Var::Var(Id vid, Type type_annotation) { +Var::Var(Id vid, Type type_annotation, Span span) { ObjectPtr n = make_object(); n->vid = std::move(vid); n->type_annotation = std::move(type_annotation); + n->span = std::move(span); data_ = std::move(n); } @@ -105,12 +108,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); -Call::Call(Expr op, Array args, Attrs attrs, Array type_args) { +Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span span) { ObjectPtr n = make_object(); n->op = std::move(op); n->args = std::move(args); n->attrs = std::move(attrs); n->type_args = std::move(type_args); + n->span = std::move(span); data_ = std::move(n); } @@ -128,11 +132,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << node->type_args << ")"; }); -Let::Let(Var var, Expr value, Expr body) { +Let::Let(Var var, Expr value, Expr body, Span span) { ObjectPtr n = make_object(); n->var = std::move(var); n->value = std::move(value); n->body = std::move(body); + n->span = std::move(span); data_ = std::move(n); } @@ -148,11 +153,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "LetNode(" << node->var << ", " << node->value << ", " << node->body << ")"; }); -If::If(Expr cond, Expr true_branch, Expr false_branch) { +If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { ObjectPtr n = make_object(); n->cond = std::move(cond); n->true_branch = std::move(true_branch); n->false_branch = std::move(false_branch); + n->span = std::move(span); data_ = std::move(n); } @@ -170,10 +176,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << node->false_branch << ")"; }); -TupleGetItem::TupleGetItem(Expr tuple, int index) { +TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { ObjectPtr n = make_object(); n->tuple = std::move(tuple); n->index = index; + n->span = std::move(span); data_ = std::move(n); } @@ -189,9 +196,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; }); -RefCreate::RefCreate(Expr value) { +RefCreate::RefCreate(Expr value, Span span) { ObjectPtr n = make_object(); n->value = std::move(value); + n->span = std::move(span); data_ = std::move(n); } @@ -207,9 +215,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "RefCreateNode(" << node->value << ")"; }); -RefRead::RefRead(Expr ref) { +RefRead::RefRead(Expr ref, Span span) { ObjectPtr n = make_object(); n->ref = std::move(ref); + n->span = std::move(span); data_ = std::move(n); } @@ -223,10 +232,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "RefReadNode(" << node->ref << ")"; }); -RefWrite::RefWrite(Expr ref, Expr value) { +RefWrite::RefWrite(Expr ref, Expr value, Span span) { ObjectPtr n = make_object(); n->ref = std::move(ref); n->value = std::move(value); + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index ad15453e3058f..cbc41d225d4b5 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -199,7 +199,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) { if (op->type_annotation.defined()) { auto type = this->VisitType(op->type_annotation); if (!op->type_annotation.same_as(type)) { - return Var(op->vid, type); + return Var(op->vid, type, op->span); } } // default case return self. @@ -224,7 +224,7 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) { if (all_fields_unchanged) { return GetRef(op); } else { - return Tuple(fields); + return Tuple(fields, op->span); } } @@ -253,7 +253,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { body.same_as(op->body)) { return GetRef(op); } else { - return Function(params, body, ret_type, ty_params, op->attrs); + return Function(params, body, ret_type, ty_params, op->attrs, op->span); } } @@ -278,7 +278,7 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) { if (unchanged) { return GetRef(call_node); } else { - return Call(new_op, call_args, call_node->attrs, ty_args); + return Call(new_op, call_args, call_node->attrs, ty_args, call_node->span); } } @@ -290,7 +290,7 @@ Expr ExprMutator::VisitExpr_(const LetNode* op) { if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return Let(var, value, body); + return Let(var, value, body, op->span); } } @@ -302,16 +302,16 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) { op->false_branch.same_as(false_b)) { return GetRef(op); } else { - return If(guard, true_b, false_b); + return If(guard, true_b, false_b, op->span); } } -Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) { - auto t = this->Mutate(g->tuple); - if (g->tuple == t) { - return GetRef(g); +Expr ExprMutator::VisitExpr_(const TupleGetItemNode* get_item) { + auto t = this->Mutate(get_item->tuple); + if (get_item->tuple == t) { + return GetRef(get_item); } else { - return TupleGetItem(t, g->index); + return TupleGetItem(t, get_item->index, get_item->span); } } @@ -320,7 +320,7 @@ Expr ExprMutator::VisitExpr_(const RefCreateNode* op) { if (value.same_as(op->value)) { return GetRef(op); } else { - return RefCreate(value); + return RefCreate(value, op->span); } } @@ -329,7 +329,7 @@ Expr ExprMutator::VisitExpr_(const RefReadNode* op) { if (ref.same_as(op->ref)) { return GetRef(op); } else { - return RefRead(ref); + return RefRead(ref, op->span); } } @@ -339,7 +339,7 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) { if (ref.same_as(op->ref) && value.same_as(op->value)) { return GetRef(op); } else { - return RefWrite(ref, value); + return RefWrite(ref, value, op->span); } } @@ -355,10 +355,11 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) { } Expr data = Mutate(m->data); unchanged &= data.same_as(m->data); + if (unchanged) { return GetRef(m); } - return Match(data, clauses, m->complete); + return Match(data, clauses, m->complete, m->span); } Clause ExprMutator::VisitClause(const Clause& c) { @@ -386,22 +387,25 @@ void ExprVisitor::VisitExpr(const Expr& expr) { } void ExprVisitor::VisitExpr_(const VarNode* op) { + this->VisitSpan(op->span); if (op->type_annotation.defined()) { this->VisitType(op->type_annotation); } } -void ExprVisitor::VisitExpr_(const GlobalVarNode* op) {} +void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); } -void ExprVisitor::VisitExpr_(const ConstantNode* op) {} +void ExprVisitor::VisitExpr_(const ConstantNode* op) { this->VisitSpan(op->span); } void ExprVisitor::VisitExpr_(const TupleNode* op) { + this->VisitSpan(op->span); for (auto field : op->fields) { this->VisitExpr(field); } } void ExprVisitor::VisitExpr_(const FunctionNode* op) { + this->VisitSpan(op->span); for (auto param : op->params) { this->VisitExpr(param); } @@ -410,6 +414,7 @@ void ExprVisitor::VisitExpr_(const FunctionNode* op) { } void ExprVisitor::VisitExpr_(const CallNode* op) { + this->VisitSpan(op->span); this->VisitExpr(op->op); for (auto ty_arg : op->type_args) { @@ -422,12 +427,14 @@ void ExprVisitor::VisitExpr_(const CallNode* op) { } void ExprVisitor::VisitExpr_(const LetNode* op) { + this->VisitSpan(op->span); this->VisitExpr(op->value); this->VisitExpr(op->var); this->VisitExpr(op->body); } void ExprVisitor::VisitExpr_(const IfNode* op) { + this->VisitSpan(op->span); this->VisitExpr(op->cond); this->VisitExpr(op->true_branch); this->VisitExpr(op->false_branch); @@ -435,18 +442,29 @@ void ExprVisitor::VisitExpr_(const IfNode* op) { void ExprVisitor::VisitExpr_(const OpNode* op) { return; } -void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); } +void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->tuple); +} -void ExprVisitor::VisitExpr_(const RefCreateNode* op) { this->VisitExpr(op->value); } +void ExprVisitor::VisitExpr_(const RefCreateNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->value); +} -void ExprVisitor::VisitExpr_(const RefReadNode* op) { this->VisitExpr(op->ref); } +void ExprVisitor::VisitExpr_(const RefReadNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->ref); +} void ExprVisitor::VisitExpr_(const RefWriteNode* op) { + this->VisitSpan(op->span); this->VisitExpr(op->ref); this->VisitExpr(op->value); } void ExprVisitor::VisitExpr_(const ConstructorNode* op) { + // TODO(@jroesch): visit spans for (const Type& t : op->inputs) { this->VisitType(t); } @@ -454,6 +472,7 @@ void ExprVisitor::VisitExpr_(const ConstructorNode* op) { } void ExprVisitor::VisitExpr_(const MatchNode* op) { + this->VisitSpan(op->span); this->VisitExpr(op->data); for (const Clause& c : op->clauses) { this->VisitClause(c); @@ -461,6 +480,7 @@ void ExprVisitor::VisitExpr_(const MatchNode* op) { } void ExprVisitor::VisitClause(const Clause& op) { + // TODO(@jroesch): visit spans this->VisitPattern(op->lhs); this->VisitExpr(op->rhs); } @@ -469,6 +489,8 @@ void ExprVisitor::VisitPattern(const Pattern& p) { return; } void ExprVisitor::VisitType(const Type& t) { return; } +void ExprVisitor::VisitSpan(const Span& span) { return; } + // visitor to implement apply class ExprApplyVisit : public ExprVisitor { public: diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 5312e6d48447c..1439e8b59cf07 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -27,7 +27,7 @@ namespace tvm { namespace relay { Function::Function(tvm::Array params, Expr body, Type ret_type, - tvm::Array type_params, DictAttrs attrs) { + tvm::Array type_params, DictAttrs attrs, Span span) { ObjectPtr n = make_object(); CHECK(params.defined()); CHECK(type_params.defined()); @@ -36,6 +36,7 @@ Function::Function(tvm::Array params, Expr body, Type ret_type, n->ret_type = std::move(ret_type); n->type_params = std::move(type_params); n->attrs = std::move(attrs); + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 45e1af1c960f7..7182f0e96f0f4 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -26,16 +26,16 @@ * most efficient code we need to obtain type information for the * IR. * - * Like computation graphs the IR leaves most type information - * implicit and relies performing analysis of the program to - * generate this information. + * Similar to previous computation graph based IRs, the Relay IR leaves + * type information implicit and computes types by performing program + * analysis. * - * This pass given an expression `e` will infer a type `t` for - * the expression simultaneous checking the property `e : t` - * (i.e we can show e has type t). + * Given an expression `e` this pass infers a type `t` for + * the expression as well as simultaneously checking the property `e : t` + * (i.e., we can show e has type t). * - * If we can not infer a type or there are conflicting typing - * constraints we will trigger an error. + * If we can not infer a type or there is a conflicting + * constraint it will emit errors. */ #include #include diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index a6f23107c42cf..d23fa3a1f7aa2 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -256,6 +256,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { std::vector control_deps; // subgraphs std::vector subgraphs; + // JSON Loader void LoadAttrs(dmlc::JSONReader* reader, TVMOpParam* param) { int bitmask = 0; diff --git a/tests/lint/rat-excludes b/tests/lint/rat-excludes index 0c3ab601e04ab..5f0445134dea1 100644 --- a/tests/lint/rat-excludes +++ b/tests/lint/rat-excludes @@ -37,11 +37,6 @@ dist .node_repl_history node_modules -# Relay parser: they are generated by ANTLR. -RelayLexer.py -RelayParser.py -RelayVisitor.py - # Specific files package-list MANIFEST diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index fed257fafd4a2..b53423a2f4c14 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -31,10 +31,12 @@ def check_json_roundtrip(node): # Span def test_span(): - span = relay.Span(None, 1, 1) - assert span.source == None + span = relay.Span(None, 1, 2, 3, 4) + assert span.source_name == None assert span.line == 1 - assert span.column == 1 + assert span.end_line == 2 + assert span.column == 3 + assert span.end_column == 4 assert span.same_as(span) assert span == span assert isinstance(span, relay.base.Span) @@ -43,9 +45,11 @@ def test_span(): # span is not a node so we can't use graph_equal # to test the round trip back = tvm.ir.load_json(tvm.ir.save_json(span)) - assert back.source == span.source + assert back.source_name == span.source_name assert back.line == span.line + assert back.end_line == span.end_line assert back.column == span.column + assert back.end_column == span.end_column def test_constant(): diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 1e4fe6b668307..3fcc7dab5bcd4 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -17,13 +17,15 @@ import tvm from tvm import te from tvm import relay +import tvm.relay.testing import pytest from numpy import isclose from typing import Union from functools import wraps -raises_parse_error = pytest.mark.xfail(raises=tvm._ffi.base.TVMError) -SEMVER = "v0.0.4" + + +SEMVER = "#[version = \"0.0.5\"]\n" BINARY_OPS = { "*": relay.multiply, @@ -73,29 +75,40 @@ def assert_graph_equal(lhs, rhs): def graph_equal(lhs, rhs): return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True) +def roundtrip_expr(expr): + text = tvm.relay.Expr.astext(expr, show_meta_data=False) + x = tvm.parser.parse_expr(text) + assert_graph_equal(x, expr) +# Testing Utilities for expressions. def roundtrip(expr): - x = relay.fromtext(expr.astext()) + x = tvm.parser.fromtext(expr.astext()) assert_graph_equal(x, expr) - def parse_text(code): - expr = relay.fromtext(SEMVER + "\n" + code) - roundtrip(expr) + expr = tvm.parser.parse_expr(code) + roundtrip_expr(expr) return expr - def parses_as(code, expr): # type: (str, relay.Expr) -> bool parsed = parse_text(code) result = graph_equal(parsed, expr) return result +# Testing Utilities for full modules. +def parse_module(code): + mod = tvm.parser.parse(SEMVER + code) + roundtrip(mod) + return mod def assert_parses_as(code, expr): parsed = parse_text(code) assert_graph_equal(parsed, expr) +def assert_parse_module_as(code, mod): + parsed = parse_module(code) + assert_graph_equal(parsed, mod) def get_scalar(x): # type: (relay.Constant) -> (Union[float, int, bool]) @@ -176,7 +189,8 @@ def test_bool_literal(): def test_negative(): - assert isinstance(parse_text("let %x = 1; -%x").body, relay.Call) + # need to handle parsing non-literal operations + # assert isinstance(parse_text("let %x = 1; -%x").body, relay.Call) assert get_scalar(parse_text("--10")) == 10 assert get_scalar(parse_text("---10")) == -10 @@ -198,15 +212,7 @@ def test_op_assoc(): assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1")) assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))")) - -@pytest.mark.skip def test_vars(): - # temp vars won't work b/c they start with a digit - # # temp var - # temp_var = parse_text("%1") - # assert isinstance(temp_var, relay.Var) - # assert temp_var.name == "1" - # var var = parse_text("let %foo = (); %foo") assert isinstance(var.body, relay.Var) @@ -218,9 +224,20 @@ def test_vars(): assert global_var.name_hint == "foo" # operator id - op = parse_text("foo") + op = parse_text("add") assert isinstance(op, tvm.ir.Op) - assert op.name == "foo" + assert op.name == "add" + + # operator id with prefix + op = parse_text("nn.global_avg_pool2d") + assert isinstance(op, tvm.ir.Op) + assert op.name == "nn.global_avg_pool2d" + +def test_meta_ref(): + with pytest.raises(tvm.error.DiagnosticError): + meta_op = parse_text("meta[type_key][1337]") + assert meta_op.attrs.node_type_key == "type_key" + assert meta_op.attrs.node_index == 1337 def test_let(): @@ -253,7 +270,7 @@ def test_let(): def test_seq(): assert_parses_as( - "();; ()", + "(); ()", relay.Let( _, UNIT, @@ -278,19 +295,17 @@ def test_graph(): ) -@raises_parse_error -def test_graph_wrong_order(): - parse_text("%1 = (); %1") +def test_graph_single(): + assert_parses_as("%1 = (); %1", relay.Tuple([])) - -@raises_parse_error def test_let_global_var(): - parse_text("let @x = 1; ()") + with pytest.raises(tvm.error.DiagnosticError): + parse_text("let @x = 1; ()") -@raises_parse_error def test_let_op(): - parse_text("let x = 1; ()") + with pytest.raises(tvm.error.DiagnosticError): + parse_text("let x = 1; ()") def test_tuple(): @@ -348,16 +363,18 @@ def test_func(): ) ) - # attributes - assert_parses_as( - "fn (n=5) { () }", - relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5))) - ) + # Refactor the attribute syntax and printing. + # + # # attributes + # assert_parses_as( + # "fn (n=5) { () }", + # relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5))) + # ) # TODO(@jmp): Crashes if %x isn't annnotated. def test_defn(): - id_defn = parse_text( + id_defn = parse_module( """ def @id(%x: int32) -> int32 { %x @@ -367,7 +384,7 @@ def @id(%x: int32) -> int32 { def test_recursive_call(): - id_defn = parse_text( + id_defn = parse_module( """ def @id(%x: int32) -> int32 { @id(%x) @@ -393,18 +410,18 @@ def test_ifelse(): ) -@raises_parse_error def test_ifelse_scope(): - parse_text( - """ - if (True) { - let %x = (); - () - } else { - %x - } - """ - ) + with pytest.raises(tvm.error.DiagnosticError): + parse_text( + """ + if (True) { + let %x = (); + () + } else { + %x + } + """ + ) def test_call(): @@ -487,40 +504,39 @@ def test_call(): ) ) - # TODO(@jmp): re-enable after sequence parsing improvements # curried function - # curried_mult = relay.Var("curried_mult") - # assert_parses_as( - # """ - # let %curried_mult = - # fn (%x) { - # fn (%y) { - # %x * %y - # } - # }; - # %curried_mult(0); - # %curried_mult(0)(0) - # """, - # relay.Let( - # curried_mult, - # relay.Function( - # [X], - # relay.Function( - # [Y], - # relay.multiply(X, Y), - # None, - # [] - # ), - # None, - # [] - # ), - # relay.Let( - # _, - # relay.Call(curried_mult, [relay.const(0)], None, None), - # relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None) - # ) - # ) - # ) + curried_mult = relay.Var("curried_mult") + assert_parses_as( + """ + let %curried_mult = + fn (%x) { + fn (%y) { + %x * %y + } + }; + %curried_mult(0); + %curried_mult(0)(0) + """, + relay.Let( + curried_mult, + relay.Function( + [X], + relay.Function( + [Y], + relay.multiply(X, Y), + None, + [] + ), + None, + [] + ), + relay.Let( + _, + relay.Call(curried_mult, [relay.const(0)], None, None), + relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None) + ) + ) + ) # op assert_parses_as( @@ -655,7 +671,7 @@ def test_adt_defn(): [], [relay.Constructor("Nil", [], glob_typ_var)]) mod[glob_typ_var] = prog - assert_parses_as( + assert_parse_module_as( """ type Ayy { Nil } """, @@ -669,7 +685,7 @@ def test_empty_adt_defn(): glob_typ_var = relay.GlobalTypeVar("Ayy") prog = relay.TypeData(glob_typ_var, [], []) mod[glob_typ_var] = prog - assert_parses_as( + assert_parse_module_as( """ type Ayy { } """, @@ -690,7 +706,7 @@ def test_multiple_cons_defn(): relay.Constructor("Nil", [], list_var), ]) mod[list_var] = prog - assert_parses_as(LIST_DEFN, mod) + assert_parse_module_as(LIST_DEFN, mod) def test_multiple_type_param_defn(): @@ -706,7 +722,7 @@ def test_multiple_type_param_defn(): ]) mod = tvm.IRModule() mod[glob_typ_var] = prog - assert_parses_as( + assert_parse_module_as( """ type Either[A, B] { Left(A), @@ -740,7 +756,7 @@ def test_match(): input_var = relay.Var("xs", input_type) rest_var = relay.Var("rest") cons_case = relay.Let( - _, + relay.var("", type_annotation=None), UNIT, relay.add(relay.const(1), relay.Call(length_var, [rest_var]))) body = relay.Match(input_var, @@ -762,14 +778,14 @@ def test_match(): ) mod[length_var] = length_func - assert_parses_as( + assert_parse_module_as( """ %s def @length[A](%%xs: List[A]) -> int32 { %s (%%xs) { - Cons(_, %%rest) => { - ();; + Cons(_, %%rest : List[A]) => { + (); 1 + @length(%%rest) }, Nil => 0, @@ -803,7 +819,7 @@ def test_adt_cons_expr(): ) mod[make_singleton_var] = make_singleton_func - assert_parses_as( + assert_parse_module_as( """ %s @@ -815,52 +831,51 @@ def @make_singleton(%%x: int32) -> List[int32] { ) -@raises_parse_error def test_duplicate_adt_defn(): - parse_text( - """ - %s + with pytest.raises(tvm.error.DiagnosticError): + parse_module( + """ + %s - type List[A] { - Cons(A, List[A]), - Nil, - } - """ % LIST_DEFN - ) + type List[A] { + Cons(A, List[A]), + Nil, + } + """ % LIST_DEFN + ) -@raises_parse_error def test_duplicate_adt_cons(): - parse_text( - """ - type Ayy { Lmao } - type Haha { Lmao } - """ - ) + with pytest.raises(tvm.error.DiagnosticError): + parse_text( + """ + type Ayy { Lmao } + type Haha { Lmao } + """ + ) -@raises_parse_error def test_duplicate_adt_cons_defn(): - parse_text( - """ - type Ayy { Lmao } - type Lmao { Ayy } - """ - ) + with pytest.raises(tvm.error.DiagnosticError): + parse_text( + """ + type Ayy { Lmao } + type Lmao { Ayy } + """ + ) -@raises_parse_error def test_duplicate_global_var(): - parse_text( - """ - def @id[A](%%x: A) -> A { x } - def @id[A](%%x: A) -> A { x } - """ - ) + with pytest.raises(tvm.error.DiagnosticError): + parse_text( + """ + def @id[A](%x: A) -> A { x } + def @id[A](%x: A) -> A { x } + """ + ) def test_extern_adt_defn(): - # TODO(weberlo): update this test once extern is implemented mod = tvm.IRModule() extern_var = relay.GlobalTypeVar("T") @@ -868,48 +883,45 @@ def test_extern_adt_defn(): extern_def = relay.TypeData(extern_var, [typ_var], []) mod[extern_var] = extern_def - assert_parses_as( + assert_parse_module_as( """ extern type T[A] """, mod ) + def test_import_grad(): mod = tvm.IRModule() mod.import_from_std("gradient.rly") +def test_resnet(): + mod, _ = relay.testing.resnet.get_workload() + text = mod.astext() + parsed_mod = tvm.parser.parse(text) + tvm.ir.assert_structural_equal(mod, parsed_mod) + +def inline_params(mod, params): + main_fn = mod["main"] + str_to_var = {} + for param in main_fn.params: + str_to_var[param.name_hint] = param + + bind_map = {} + for param in params: + bind_map[str_to_var[param]] = relay.const(params[param]) + + body = relay.bind(main_fn.body, bind_map) + main_fn = relay.Function(relay.analysis.free_vars(body), body) + mod["main_fn"] = main_fn + return mod + +def test_resnet_inlined_params(): + mod, params = relay.testing.resnet.get_workload() + mod = inline_params(mod, params) + text = mod.astext() + parsed_mod = tvm.parser.parse(text) + tvm.ir.assert_structural_equal(mod, parsed_mod) + if __name__ == "__main__": - test_graph() - test_comments() - test_int_literal() - test_float_literal() - test_bool_literal() - test_negative() - test_bin_op() - test_parens() - test_op_assoc() - test_let() - test_seq() - test_tuple() - test_func() - test_defn() - test_recursive_call() - test_ifelse() - test_call() - test_incomplete_type() - test_builtin_types() - test_tensor_type() - test_function_type() - test_tuple_type() - test_adt_defn() - test_empty_adt_defn() - test_multiple_cons_defn() - test_multiple_type_param_defn() - test_match() - test_adt_cons_expr() - test_duplicate_adt_defn() - test_duplicate_adt_cons() - test_duplicate_adt_cons_defn() - test_duplicate_global_var() - test_extern_adt_defn() - test_import_grad() + import sys + pytest.main(sys.argv) diff --git a/tests/python/relay/test_ir_parser2.py b/tests/python/relay/test_ir_parser2.py deleted file mode 100644 index 23ba1fa850e50..0000000000000 --- a/tests/python/relay/test_ir_parser2.py +++ /dev/null @@ -1,891 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te -from tvm import relay -import pytest -from numpy import isclose -from typing import Union -from functools import wraps -raises_parse_error = pytest.mark.xfail(raises=tvm._ffi.base.TVMError) - -SEMVER = "v0.0.4" - -BINARY_OPS = { - "*": relay.multiply, - "/": relay.divide, - "+": relay.add, - "-": relay.subtract, - "<": relay.less, - ">": relay.greater, - "<=": relay.less_equal, - ">=": relay.greater_equal, - "==": relay.equal, - "!=": relay.not_equal, -} - -TYPES = { - "int8", - "int16", - "int32", - "int64", - - "uint8", - "uint16", - "uint32", - "uint64", - - "float16", - "float32", - "float64", - - "bool", - - "int8x4", - "uint1x4", - "float16x4", -} - -LIST_DEFN = """ -type List[A] { - Cons(A, List[A]), - Nil, -} -""" - -def assert_graph_equal(lhs, rhs): - tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars=True) - -def graph_equal(lhs, rhs): - return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True) - - -def roundtrip_expr(expr): - x = tvm.parser.parse_expr(str(str(expr))) - assert_graph_equal(x, expr) - -def roundtrip(expr): - x = tvm.parser.fromtext(expr.astext()) - assert_graph_equal(x, expr) - -def parse_text(code): - expr = tvm.parser.parse_expr(code) - roundtrip_expr(expr) - return expr - - -def parses_as(code, expr): - # type: (str, relay.Expr) -> bool - parsed = parse_text(code) - result = graph_equal(parsed, expr) - return result - -def parse_module(code): - mod = tvm.parser.parse(code) - roundtrip(mod) - return mod - - -def assert_parses_as(code, expr): - parsed = parse_text(code) - assert_graph_equal(parsed, expr) - -def assert_parse_module_as(code, mod): - parsed = parse_module(code) - assert_graph_equal(parsed, mod) - -def get_scalar(x): - # type: (relay.Constant) -> (Union[float, int, bool]) - return x.data.asnumpy().item() - -int32 = relay.scalar_type("int32") - -_ = relay.Var("_") -X = relay.Var("x") -Y = relay.Var("y") -X_ANNO = relay.Var("x", int32) -Y_ANNO = relay.Var("y", int32) - -UNIT = relay.Tuple([]) - - -def test_comments(): - assert_parses_as( - """ - // This is a line comment! - () - """, - UNIT - ) - - assert_parses_as( - """ - /* This is a block comment! - This is still a block comment! - */ - () - """, - UNIT - ) - - assert_parses_as( - """ - /* This is a block comment! - /*Block comment is recursive!*/ - */ - () - """, - UNIT - ) - - -def test_int_literal(): - assert isinstance(parse_text("1"), relay.Constant) - assert isinstance(parse_text("1").data, tvm.nd.NDArray) - - assert get_scalar(parse_text("1")) == 1 - assert get_scalar(parse_text("10")) == 10 - assert get_scalar(parse_text("0")) == 0 - assert get_scalar(parse_text("-100")) == -100 - assert get_scalar(parse_text("-05")) == -5 - - -def test_float_literal(): - assert get_scalar(parse_text("1.0f")) == 1.0 - assert isclose(get_scalar(parse_text("1.56667f")), 1.56667) - assert get_scalar(parse_text("0.0f")) == 0.0 - assert get_scalar(parse_text("-10.0f")) == -10.0 - - # scientific notation - assert isclose(get_scalar(parse_text("1e-1f")), 1e-1) - assert get_scalar(parse_text("1e+1f")) == 1e+1 - assert isclose(get_scalar(parse_text("1E-1f")), 1E-1) - assert get_scalar(parse_text("1E+1f")) == 1E+1 - assert isclose(get_scalar(parse_text("1.0e-1f")), 1.0e-1) - assert get_scalar(parse_text("1.0e+1f")) == 1.0e+1 - assert isclose(get_scalar(parse_text("1.0E-1f")), 1.0E-1) - assert get_scalar(parse_text("1.0E+1f")) == 1.0E+1 - - -def test_bool_literal(): - assert get_scalar(parse_text("True")) == True - assert get_scalar(parse_text("False")) == False - - -def test_negative(): - # need to handle parsing non-literal operations - # assert isinstance(parse_text("let %x = 1; -%x").body, relay.Call) - assert get_scalar(parse_text("--10")) == 10 - assert get_scalar(parse_text("---10")) == -10 - - -def test_bin_op(): - for bin_op in BINARY_OPS.keys(): - assert_parses_as( - "1 {} 1".format(bin_op), - BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1)) - ) - - -def test_parens(): - assert graph_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1")) - assert not graph_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)")) - - -def test_op_assoc(): - assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1")) - assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))")) - - -def test_vars(): - # var - var = parse_text("let %foo = (); %foo") - assert isinstance(var.body, relay.Var) - assert var.body.name_hint == "foo" - - # global var - global_var = parse_text("@foo") - assert isinstance(global_var, relay.GlobalVar) - assert global_var.name_hint == "foo" - - # operator id - op = parse_text("add") - assert isinstance(op, tvm.ir.Op) - assert op.name == "add" - - -def test_let(): - assert_parses_as( - "let %x = 1; ()", - relay.Let( - X, - relay.const(1), - UNIT - ) - ) - - assert_parses_as( - """ - let %x = 1; - let %y = 2; - () - """, - relay.Let( - X, - relay.const(1), - relay.Let( - Y, - relay.const(2), - UNIT - ) - ) - ) - - -def test_seq(): - assert_parses_as( - "(); ()", - relay.Let( - _, - UNIT, - UNIT) - ) - - assert_parses_as( - "let %_ = 1; ()", - relay.Let( - X, - relay.const(1), - UNIT - ) - ) - - -def test_graph(): - code = "%0 = (); %1 = 1; (%0, %0, %1)" - assert_parses_as( - code, - relay.Tuple([UNIT, UNIT, relay.const(1)]) - ) - - -@raises_parse_error -def test_graph_wrong_order(): - parse_text("%1 = (); %1") - - -@raises_parse_error -def test_let_global_var(): - parse_text("let @x = 1; ()") - - -@raises_parse_error -def test_let_op(): - parse_text("let x = 1; ()") - - -def test_tuple(): - assert_parses_as("()", relay.Tuple([])) - - assert_parses_as("(0,)", relay.Tuple([relay.const(0)])) - - assert_parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)])) - - assert_parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) - - -def test_func(): - # 0 args - assert_parses_as( - "fn () { 0 }", - relay.Function( - [], - relay.const(0), - None, - [] - ) - ) - - # 1 arg - assert_parses_as( - "fn (%x) { %x }", - relay.Function( - [X], - X, - None, - [] - ) - ) - - # 2 args - assert_parses_as( - "fn (%x, %y) { %x + %y }", - relay.Function( - [X, Y], - relay.add(X, Y), - None, - [] - ) - ) - - # annotations - assert_parses_as( - "fn (%x: int32) -> int32 { %x }", - relay.Function( - [X_ANNO], - X_ANNO, - int32, - [] - ) - ) - - # Refactor the attribute syntax and printing. - # - # # attributes - # assert_parses_as( - # "fn (n=5) { () }", - # relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5))) - # ) - - -# TODO(@jmp): Crashes if %x isn't annnotated. -def test_defn(): - id_defn = parse_module( - """ - def @id(%x: int32) -> int32 { - %x - } - """) - assert isinstance(id_defn, tvm.IRModule) - - -def test_recursive_call(): - id_defn = parse_module( - """ - def @id(%x: int32) -> int32 { - @id(%x) - } - """) - assert isinstance(id_defn, tvm.IRModule) - - -def test_ifelse(): - assert_parses_as( - """ - if (True) { - 0 - } else { - 1 - } - """, - relay.If( - relay.const(True), - relay.const(0), - relay.const(1) - ) - ) - - -@raises_parse_error -def test_ifelse_scope(): - parse_text( - """ - if (True) { - let %x = (); - () - } else { - %x - } - """ - ) - - -def test_call(): - # select right function to call: simple ident case - id_func = relay.Var("id") - assert_parses_as( - """ - let %id = fn (%x) { %x }; - 10 * %id(10) - """, - relay.Let( - id_func, - relay.Function([X], X, None, []), - relay.multiply(relay.const(10), relay.Call(id_func, [relay.const(10)])) - ) - ) - - # 0 args - constant = relay.Var("constant") - assert_parses_as( - """ - let %constant = fn () { 0 }; - %constant() - """, - relay.Let( - constant, - relay.Function([], relay.const(0), None, []), - relay.Call(constant, [], None, None) - ) - ) - - # 1 arg - id_var = relay.Var("id") - assert_parses_as( - """ - let %id = fn (%x) { %x }; - %id(1) - """, - relay.Let( - id_var, - relay.Function([X], X, None, []), - relay.Call(id_var, [relay.const(1)], None, None) - ) - ) - - # 2 args - multiply = relay.Var("multiply") - assert_parses_as( - """ - let %multiply = fn (%x, %y) { %x * %y }; - %multiply(0, 0) - """, - relay.Let( - multiply, - relay.Function( - [X, Y], - relay.multiply(X, Y), - None, - [] - ), - relay.Call(multiply, [relay.const(0), relay.const(0)], None, None) - ) - ) - - # anonymous function - assert_parses_as( - """ - (fn (%x) { %x })(0) - """, - relay.Call( - relay.Function( - [X], - X, - None, - [] - ), - [relay.const(0)], - None, - None - ) - ) - - # curried function - curried_mult = relay.Var("curried_mult") - assert_parses_as( - """ - let %curried_mult = - fn (%x) { - fn (%y) { - %x * %y - } - }; - %curried_mult(0); - %curried_mult(0)(0) - """, - relay.Let( - curried_mult, - relay.Function( - [X], - relay.Function( - [Y], - relay.multiply(X, Y), - None, - [] - ), - None, - [] - ), - relay.Let( - _, - relay.Call(curried_mult, [relay.const(0)], None, None), - relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None) - ) - ) - ) - - # op - assert_parses_as( - "abs(1)", - relay.Call(relay.op.get("abs"), [relay.const(1)], None, None) - ) - -# Types - - -def test_incomplete_type(): - assert_parses_as( - "let %_ : _ = (); ()", - relay.Let( - _, - UNIT, - UNIT - ) - ) - - -def test_builtin_types(): - for builtin_type in TYPES: - parse_text("let %_ : {} = (); ()".format(builtin_type)) - - -def test_tensor_type(): - assert_parses_as( - "let %_ : Tensor[(), float32] = (); ()", - relay.Let( - relay.Var("_", relay.TensorType((), "float32")), - UNIT, - UNIT - ) - ) - - assert_parses_as( - "let %_ : Tensor[(1), float32] = (); ()", - relay.Let( - relay.Var("_", relay.TensorType((1,), "float32")), - UNIT, - UNIT - ) - ) - - assert_parses_as( - "let %_ : Tensor[(1, 1), float32] = (); ()", - relay.Let( - relay.Var("_", relay.TensorType((1, 1), "float32")), - UNIT, - UNIT - ) - ) - - -def test_function_type(): - assert_parses_as( - """ - let %_: fn () -> int32 = fn () -> int32 { 0 }; () - """, - relay.Let( - relay.Var("_", relay.FuncType([], int32, [], [])), - relay.Function([], relay.const(0), int32, []), - UNIT - ) - ) - - assert_parses_as( - """ - let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; () - """, - relay.Let( - relay.Var("_", relay.FuncType([int32], int32, [], [])), - relay.Function([relay.Var("x", int32)], relay.const(0), int32, []), - UNIT - ) - ) - - assert_parses_as( - """ - let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; () - """, - relay.Let( - relay.Var("_", relay.FuncType([int32, int32], int32, [], [])), - relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []), - UNIT - ) - ) - - -def test_tuple_type(): - assert_parses_as( - """ - let %_: () = (); () - """, - relay.Let( - relay.Var("_", relay.TupleType([])), - UNIT, - UNIT - ) - ) - - assert_parses_as( - """ - let %_: (int32,) = (0,); () - """, - relay.Let( - relay.Var("_", relay.TupleType([int32])), - relay.Tuple([relay.const(0)]), - UNIT - ) - ) - - assert_parses_as( - """ - let %_: (int32, int32) = (0, 1); () - """, - relay.Let( - relay.Var("_", relay.TupleType([int32, int32])), - relay.Tuple([relay.const(0), relay.const(1)]), - UNIT - ) - ) - - -def test_adt_defn(): - mod = tvm.IRModule() - - glob_typ_var = relay.GlobalTypeVar("Ayy") - prog = relay.TypeData( - glob_typ_var, - [], - [relay.Constructor("Nil", [], glob_typ_var)]) - mod[glob_typ_var] = prog - assert_parse_module_as( - """ - type Ayy { Nil } - """, - mod - ) - - -def test_empty_adt_defn(): - mod = tvm.IRModule() - - glob_typ_var = relay.GlobalTypeVar("Ayy") - prog = relay.TypeData(glob_typ_var, [], []) - mod[glob_typ_var] = prog - assert_parse_module_as( - """ - type Ayy { } - """, - mod - ) - - -def test_multiple_cons_defn(): - mod = tvm.IRModule() - - list_var = relay.GlobalTypeVar("List") - typ_var = relay.TypeVar("A") - prog = relay.TypeData( - list_var, - [typ_var], - [ - relay.Constructor("Cons", [typ_var, list_var(typ_var)], list_var), - relay.Constructor("Nil", [], list_var), - ]) - mod[list_var] = prog - assert_parse_module_as(LIST_DEFN, mod) - - -def test_multiple_type_param_defn(): - glob_typ_var = relay.GlobalTypeVar("Either") - typ_var_a = relay.TypeVar("A") - typ_var_b = relay.TypeVar("B") - prog = relay.TypeData( - glob_typ_var, - [typ_var_a, typ_var_b], - [ - relay.Constructor("Left", [typ_var_a], glob_typ_var), - relay.Constructor("Right", [typ_var_b], glob_typ_var), - ]) - mod = tvm.IRModule() - mod[glob_typ_var] = prog - assert_parse_module_as( - """ - type Either[A, B] { - Left(A), - Right(B), - } - """, - mod - ) - - -def test_match(): - # pair each match keyword with whether it specifies a complete match or not - match_keywords = [("match", True), ("match?", False)] - for (match_keyword, is_complete) in match_keywords: - mod = tvm.IRModule() - - list_var = relay.GlobalTypeVar("List") - typ_var = relay.TypeVar("A") - cons_constructor = relay.Constructor( - "Cons", [typ_var, list_var(typ_var)], list_var) - nil_constructor = relay.Constructor("Nil", [], list_var) - list_def = relay.TypeData( - list_var, - [typ_var], - [cons_constructor, nil_constructor]) - mod[list_var] = list_def - - length_var = relay.GlobalVar("length") - typ_var = relay.TypeVar("A") - input_type = list_var(typ_var) - input_var = relay.Var("xs", input_type) - rest_var = relay.Var("rest") - cons_case = relay.Let( - relay.var("", type_annotation=None), - UNIT, - relay.add(relay.const(1), relay.Call(length_var, [rest_var]))) - body = relay.Match(input_var, - [relay.Clause( - relay.PatternConstructor( - cons_constructor, - [relay.PatternWildcard(), relay.PatternVar(rest_var)]), - cons_case), - relay.Clause( - relay.PatternConstructor(nil_constructor, []), - relay.const(0))], - complete=is_complete - ) - length_func = relay.Function( - [input_var], - body, - int32, - [typ_var] - ) - mod[length_var] = length_func - - assert_parse_module_as( - """ - %s - - def @length[A](%%xs: List[A]) -> int32 { - %s (%%xs) { - Cons(_, %%rest : List[A]) => { - (); - 1 + @length(%%rest) - }, - Nil => 0, - } - } - """ % (LIST_DEFN, match_keyword), - mod - ) - - -def test_adt_cons_expr(): - mod = tvm.IRModule() - - list_var = relay.GlobalTypeVar("List") - typ_var = relay.TypeVar("A") - cons_constructor = relay.Constructor( - "Cons", [typ_var, list_var(typ_var)], list_var) - nil_constructor = relay.Constructor("Nil", [], list_var) - list_def = relay.TypeData( - list_var, - [typ_var], - [cons_constructor, nil_constructor]) - mod[list_var] = list_def - - make_singleton_var = relay.GlobalVar("make_singleton") - input_var = relay.Var("x", int32) - make_singleton_func = relay.Function( - [input_var], - cons_constructor(input_var, nil_constructor()), - list_var(int32) - ) - mod[make_singleton_var] = make_singleton_func - - assert_parse_module_as( - """ - %s - - def @make_singleton(%%x: int32) -> List[int32] { - Cons(%%x, Nil) - } - """ % LIST_DEFN, - mod - ) - - -@raises_parse_error -def test_duplicate_adt_defn(): - parse_module( - """ - %s - - type List[A] { - Cons(A, List[A]), - Nil, - } - """ % LIST_DEFN - ) - - -@raises_parse_error -def test_duplicate_adt_cons(): - parse_text( - """ - type Ayy { Lmao } - type Haha { Lmao } - """ - ) - - -@raises_parse_error -def test_duplicate_adt_cons_defn(): - parse_text( - """ - type Ayy { Lmao } - type Lmao { Ayy } - """ - ) - - -@raises_parse_error -def test_duplicate_global_var(): - parse_text( - """ - def @id[A](%x: A) -> A { x } - def @id[A](%x: A) -> A { x } - """ - ) - - -def test_extern_adt_defn(): - # TODO(weberlo): update this test once extern is implemented - mod = tvm.IRModule() - - extern_var = relay.GlobalTypeVar("T") - typ_var = relay.TypeVar("A") - extern_def = relay.TypeData(extern_var, [typ_var], []) - mod[extern_var] = extern_def - - assert_parse_module_as( - """ - extern type T[A] - """, - mod - ) - -@pytest.mark.skip("not yet tested on parser 2.0") -def test_import_grad(): - mod = tvm.IRModule() - mod.import_from_std("gradient.rly") - -if __name__ == "__main__": - import sys - pytest.main(sys.argv) diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 2a88c0c99ae7b..52551bf68e77d 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -17,28 +17,29 @@ import tvm from tvm import te from tvm import relay -import tvm.relay.testing +from tvm.relay import testing import numpy as np from tvm.relay import Expr from tvm.relay.analysis import free_vars -do_print = [False] +DEBUG_PRINT = False -SEMVER = "v0.0.4\n" +SEMVER = "#[version = \"0.0.5\"]\n" -def astext(p, unify_free_vars=False): - txt = p.astext() - if isinstance(p, Expr) and free_vars(p): - return txt - x = relay.fromtext(txt) - if unify_free_vars: - tvm.ir.assert_structural_equal(x, p, map_free_vars=True) +def astext(program, unify_free_vars=False): + text = program.astext() + print(text) + if isinstance(program, Expr): + roundtrip_program = tvm.parser.parse_expr(text) else: - tvm.ir.assert_structural_equal(x, p) - return txt + roundtrip_program = tvm.parser.fromtext(text) + + tvm.ir.assert_structural_equal(roundtrip_program, program, map_free_vars=True) + + return text def show(text): - if do_print[0]: + if DEBUG_PRINT: print("---------------------------") print(text) @@ -135,55 +136,55 @@ def test_variable_name(): def test_mlp(): - net, params = tvm.relay.testing.mlp.get_workload(batch_size=1) + net, _ = tvm.relay.testing.mlp.get_workload(batch_size=1) astext(net) def test_resnet(): - net, params = tvm.relay.testing.resnet.get_workload(batch_size=1) + net, _ = tvm.relay.testing.resnet.get_workload(batch_size=1) astext(net) def test_mobilenet(): - net, params = tvm.relay.testing.mobilenet.get_workload(batch_size=1) + net, _ = tvm.relay.testing.mobilenet.get_workload(batch_size=1) astext(net) def test_dqn(): - net, params = tvm.relay.testing.dqn.get_workload(batch_size=1) + net, _ = tvm.relay.testing.dqn.get_workload(batch_size=1) astext(net) def test_dcgan(): - net, params = tvm.relay.testing.dcgan.get_workload(batch_size=1) + net, _ = tvm.relay.testing.dcgan.get_workload(batch_size=1) astext(net) def test_lstm(): - net, params = tvm.relay.testing.lstm.get_workload(1, 1) + net, _ = tvm.relay.testing.lstm.get_workload(1, 1) astext(net) - net, params = tvm.relay.testing.lstm.get_workload(4, 4) + net, _ = tvm.relay.testing.lstm.get_workload(4, 4) astext(net) def test_inception_v3(): - net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1) + net, _ = tvm.relay.testing.inception_v3.get_workload(batch_size=1) astext(net) def test_squeezenet(): for version in ['1.0', '1.1']: - net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version) + net, _ = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version) astext(net) def test_vgg(): - net, params = tvm.relay.testing.vgg.get_workload(batch_size=1) + net, _ = tvm.relay.testing.vgg.get_workload(batch_size=1) astext(net) def test_densenet(): - net, params = tvm.relay.testing.densenet.get_workload(batch_size=1) + net, _ = tvm.relay.testing.densenet.get_workload(batch_size=1) astext(net) @@ -232,7 +233,7 @@ def @main[A]() -> fn (A, List[A]) -> List[A] { Cons } """ - mod = relay.fromtext(SEMVER + type_def_str + main_def_str) + mod = tvm.parser.parse(SEMVER + type_def_str + main_def_str) mod_str = str(mod) # ensure constructors are printed correctly in type definitions (with their # signature) and as exprs (without their signature) @@ -250,25 +251,5 @@ def test_null_attribute(): if __name__ == "__main__": - do_print[0] = True - test_lstm() - test_zeros() - test_meta_data() - test_let_inlining() - test_resnet() - test_mobilenet() - test_mlp() - test_dqn() - test_dcgan() - test_squeezenet() - test_inception_v3() - test_vgg() - test_densenet() - test_func() - test_env() - test_call_attrs() - test_let_if_scope() - test_variable_name() - test_call_node_order() - test_unapplied_constructor() - test_null_attribute() + import sys + pytext.argv(sys.argv) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index f65407acbcc9a..c0a990ba9d2e8 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -59,9 +59,9 @@ def test_checkpoint_alpha_equal(): mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df)) df = mod["main"] - df_parsed = relay.parser.fromtext( + df_parsed = tvm.parser.parse_expr( """ - v0.0.4 + #[version = "0.0.5"] fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32], %z: Tensor[(1), float32], %w: Tensor[(1), float32]) -> (Tensor[(1), float32], @@ -115,9 +115,9 @@ def test_checkpoint_alpha_equal_tuple(): mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df)) df = mod["main"] - df_parsed = relay.parser.fromtext( + df_parsed = tvm.parser.parse_expr( """ - v0.0.4 + #[version = "0.0.5"] fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32], %z: Tensor[(1), float32], %w: Tensor[(1), float32]) -> ((Tensor[(1), float32], Tensor[(1), float32]), diff --git a/tests/python/relay/test_pass_eta_expand.py b/tests/python/relay/test_pass_eta_expand.py index e0a189b5c2eea..05c5f0328e22b 100644 --- a/tests/python/relay/test_pass_eta_expand.py +++ b/tests/python/relay/test_pass_eta_expand.py @@ -24,24 +24,24 @@ import tvm.relay.transform as _transform def test_eta_expand_global_var(): - mod = relay.fromtext(r""" - v0.0.4 + mod = tvm.parser.fromtext(r""" + #[version = "0.0.5"] def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] { %x } - def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) { + def @main() -> fn(Tensor[(), int32]) -> Tensor[(), int32] { @aux } """) seq = tvm.transform.Sequential([_transform.EtaExpand(expand_global_var=True)]) with tvm.transform.PassContext(opt_level=3): mod = seq(mod) - expected = relay.fromtext(r""" - v0.0.4 + expected = tvm.parser.fromtext(r""" + #[version = "0.0.5"] def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] { %x } - def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) { + def @main() -> fn(Tensor[(), int32]) -> Tensor[(), int32] { fn (%x: Tensor[(), int32]) -> Tensor[(), int32] { @aux(%x) } @@ -52,26 +52,26 @@ def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) { def test_eta_expand_constructor(): - mod = relay.fromtext(r""" - v0.0.4 + mod = tvm.parser.fromtext(r""" + #[version = "0.0.5"] type List[A] { Cons(A, List[A]), Nil, } - def @main[A]() -> (fn(A, List[A]) -> List[A]) { + def @main[A]() -> fn(A, List[A]) -> List[A] { Cons } """) seq = tvm.transform.Sequential([_transform.EtaExpand(expand_constructor=True)]) with tvm.transform.PassContext(opt_level=3): mod = seq(mod) - expected = relay.fromtext(r""" - v0.0.4 + expected = tvm.parser.fromtext(r""" + #[version = "0.0.5"] type List[A] { Cons(A, List[A]), Nil, } - def @main[A]() -> (fn(A, List[A]) -> List[A]) { + def @main[A]() -> fn(A, List[A]) -> List[A] { fn [A](%x: A, %xs: List[A]) -> List[A] { Cons(%x, %xs) } diff --git a/tests/python/relay/test_pass_unmatched_cases.py b/tests/python/relay/test_pass_unmatched_cases.py index 42344bccabaaa..07193e104a7c8 100644 --- a/tests/python/relay/test_pass_unmatched_cases.py +++ b/tests/python/relay/test_pass_unmatched_cases.py @@ -279,7 +279,7 @@ def test_tuple_match(): def test_inf_loop_case(): code = """ -v0.0.4 +#[version = "0.0.5"] type Arith[A] { Zero, Const(A), @@ -294,7 +294,7 @@ def @shallow_opt[A](%a: Arith[A]) -> Arith[A] { } } """ - relay.fromtext(code) + tvm.parser.fromtext(code) # fromtext parse the module, then checked it (which include strictness checking). if __name__ == "__main__":