From 8c72766a85edd33bf0acc70895619e66f6d52ed6 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 16:27:52 -0700 Subject: [PATCH 01/88] Add stripped down version of expr.h and type.h This commit adds a simplified version of type.h and expr.h from the previous Relay version. We implement the basic data types and the associated machinery for exporting these to Python, as well as tests that they can be constructed, all fields are live, and can be printed using `str`. --- CMakeLists.txt | 6 + include/tvm/relay/base.h | 154 ++++++++++++ include/tvm/relay/expr.h | 361 ++++++++++++++++++++++++++++ include/tvm/relay/type.h | 243 +++++++++++++++++++ python/tvm/relay/__init__.py | 12 + python/tvm/relay/_make.py | 9 + python/tvm/relay/_make.pyi | 91 +++++++ python/tvm/relay/base.py | 27 +++ python/tvm/relay/expr.py | 69 ++++++ python/tvm/relay/make.py | 20 ++ python/tvm/relay/type.py | 51 ++++ src/relay/base.cc | 40 +++ src/relay/expr.cc | 181 ++++++++++++++ src/relay/type.cc | 100 ++++++++ tests/python/relay/test_ir_nodes.py | 154 ++++++++++++ 15 files changed, 1518 insertions(+) create mode 100644 include/tvm/relay/base.h create mode 100644 include/tvm/relay/expr.h create mode 100644 include/tvm/relay/type.h create mode 100644 python/tvm/relay/__init__.py create mode 100644 python/tvm/relay/_make.py create mode 100644 python/tvm/relay/_make.pyi create mode 100644 python/tvm/relay/base.py create mode 100644 python/tvm/relay/expr.py create mode 100644 python/tvm/relay/make.py create mode 100644 python/tvm/relay/type.py create mode 100644 src/relay/base.cc create mode 100644 src/relay/expr.cc create mode 100644 src/relay/type.cc create mode 100644 tests/python/relay/test_ir_nodes.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 572f4aef1432..65a7d9e36e2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -104,6 +104,12 @@ file(GLOB COMPILER_SRCS src/schedule/*.cc ) +file(GLOB_RECURSE RELAY_SRCS + src/relay/*.cc + ) +list(APPEND COMPILER_SRCS ${RELAY_SRCS}) + + if(NOT MSVC) file(GLOB COMPILER_VERILOG_SRCS src/codegen/verilog/*.cc) list(APPEND COMPILER_SRCS ${COMPILER_VERILOG_SRCS}) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h new file mode 100644 index 000000000000..3b31aae52617 --- /dev/null +++ b/include/tvm/relay/base.h @@ -0,0 +1,154 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/base.h + * \brief Base data structure for relay. + */ +#ifndef TVM_RELAY_BASE_H_ +#define TVM_RELAY_BASE_H_ + +#include +#include +#include +#include + +namespace tvm { +/*! + * \brief Relay: high level functional IR + */ +namespace relay { +/*! + * \brief we always used NodeRef for referencing nodes. + * + * By default, NodePtr is a std::shared_ptr of node + */ +using NodeRef = tvm::NodeRef; + +/*! + * \brief Content data type. + */ +using DataType = ::tvm::Type; + +/*! + * \brief Symbolic expression for tensor shape. + */ +using ShapeExpr = ::tvm::Expr; + +/*! + * \brief Hash function for nodes. + * e.g. std::unordered_map + */ +using NodeHash = ::tvm::NodeHash; +/*! + * \brief Equality check function for nodes. + */ +using NodeEqual = ::tvm::NodeEqual; + +/*! + * \brief Macro to make it easy to define node ref type given node + * \param TypeName The name of the reference type. + * \param NodeName The internal contrainer name. + * \param NodeRefBase The base type. + */ +#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ + class TypeName : public NodeRefBase { \ + public: \ + TypeName() {} \ + explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefBase(n) {} \ + const NodeName* operator->() const { \ + return static_cast(node_.get()); \ + } \ + using ContainerType = NodeName; \ + }; + + +/*! + * \brief The source name in the Span + * \sa SourceNameNode, Span + */ +class SourceName; +/*! + * \brief The source name in the Span + */ +class SourceNameNode : public Node { + public: + /*! \brief The source name */ + std::string name; + // override attr visitor + void VisitAttrs(AttrVisitor* v) final { + v->Visit("name", &name); + } + + TVM_DLL static SourceName make(std::string name); + + static constexpr const char* _type_key = "relay.SourceName"; + TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node); +}; + +RELAY_DEFINE_NODE_REF(SourceName, SourceNameNode, NodeRef); + +/*! + * \brief Span information for debugging purposes + */ +class Span; +/*! + * \brief Stores locations in frontend source that generated a node. + * + */ +class SpanNode : public Node { + public: + /*! \brief The source name */ + SourceName source; + /*! \brief Line number */ + int lineno; + /*! \brief column offset */ + int col_offset; + // override attr visitor + void VisitAttrs(AttrVisitor* v) final { + v->Visit("source", &source); + v->Visit("lineno", &lineno); + v->Visit("col_offset", &col_offset); + } + + TVM_DLL static Span make(SourceName source, int lineno, int col_offset); + + static constexpr const char* _type_key = "relay.Span"; + TVM_DECLARE_NODE_TYPE_INFO(SpanNode, Node); +}; + +RELAY_DEFINE_NODE_REF(Span, SpanNode, NodeRef); + +/*! + * \brief This is the base node container of all relay structures. + */ +class RelayNode : public Node { + public: + /*! \brief The debug information, can be null, check with span.defined() */ + mutable Span span; + + static constexpr const char* _type_key = "relay.Node"; + TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node); +}; + +/*! + * \brief Get a reference type from a Node ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the node alive beyond the scope of the function. + * + * \param ptr The node pointer + * \tparam RefType The reference type + * \tparam NodeType The node type + * \return The corresponding RefType + */ +template +RefType GetRef(const NodeType* ptr) { + static_assert(std::is_same::value, + "Can only cast to the ref of same container type"); + return RefType(const_cast(ptr)->shared_from_this()); +} + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BASE_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h new file mode 100644 index 000000000000..b830c7ce04ef --- /dev/null +++ b/include/tvm/relay/expr.h @@ -0,0 +1,361 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/expr.h + * \brief Relay expression IR Node. + */ +#ifndef TVM_RELAY_EXPR_H_ +#define TVM_RELAY_EXPR_H_ + +#include +#include +#include +#include +#include "./base.h" +#include "./type.h" + +namespace tvm { +namespace relay { +/*! + * \brief Relay expression. + */ +class Expr; +/*! + * \brief Base type of the Relay type hiearchy. + */ +class ExprNode : public RelayNode { + public: + /*! + * \brief Stores the result of type inference(type checking). + * + * \note This can be undefined before type inference. + * this value is discarded during serialization. + */ + Type checked_type_ = Type(nullptr); + /*! + * \return The checked_type + */ + const Type& checked_type() const { + CHECK(checked_type_.defined()) << "internal error: the type checker has " + "not populated the checked_type " + << "field for this node"; + return this->checked_type_; + } + + static constexpr const char* _type_key = "relay.Expr"; + TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode); +}; + +RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef); + +/*! + * \brief Constant tensor, backed by an NDArray on cpu(0). + * + * \note scalar constants are represented by rank-0 const tensor. + * Constant folding are handled uniformly via Tensor types. + */ +class Constant; +/*! + * \brief Constant tensor type. + */ +class ConstantNode : public ExprNode { + public: + /*! \brief The data of the tensor */ + runtime::NDArray data; + + // TODO(tqchen) add the function after we get TensorType constructor + // TODO(tqchen) create simple TensorType constructor for concrete types. + /*! \return The corresponding tensor type of the data */ + TensorType tensor_type() const; + + /*! \return whether it is scalar(rank-0 tensor) */ + bool is_scalar() const { return data->ndim == 0; } + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("data", &data); + v->Visit("span", &span); + } + + TVM_DLL static Constant make(runtime::NDArray data); + + static constexpr const char* _type_key = "relay.Constant"; + TVM_DECLARE_NODE_TYPE_INFO(ConstantNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Constant, ConstantNode, Expr); + +/*! \brief Tuple of multiple Exprs */ +class Tuple; +/*! \brief Tuple container */ +class TupleNode : public ExprNode { + public: + /*! \brief the fields of the tuple */ + tvm::Array fields; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("fields", &fields); + v->Visit("span", &span); + } + + TVM_DLL static Tuple make(tvm::Array fields); + + static constexpr const char* _type_key = "relay.Tuple"; + TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr); + +/*! + * \brief Local variables used in the let expression. + * This is similar to Var that is being used in the low level tensor expression. + * + * \note Each LocalVar is bind only once and is immutable/ + */ +class LocalVar; +/*! \brief Container for LocalVar */ +class LocalVarNode : public ExprNode { + public: + /*! \brief The name of the variable, this only acts as a hint. */ + std::string name_hint; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name_hint", &name_hint); + } + + TVM_DLL static LocalVar make(std::string name_hint); + + static constexpr const char* _type_key = "relay.LocalVar"; + TVM_DECLARE_NODE_TYPE_INFO(LocalVarNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(LocalVar, LocalVarNode, Expr); + +/*! + * \brief Global variable that leaves in the top-level environment. + * This is used to enable recursive calls between function. + * + * \note GlobalVar can only corresponds to functions. + */ +class GlobalVar; +/*! \brief A GlobalId from the node's current type to target type. */ +class GlobalVarNode : public ExprNode { + public: + /*! \brief The name of the variable, this only acts as a hint. */ + std::string name_hint; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name_hint", &name_hint); + } + + TVM_DLL static GlobalVar make(std::string name_hint); + + static constexpr const char* _type_key = "relay.GlobalVar"; + TVM_DECLARE_NODE_TYPE_INFO(GlobalVarNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr); + +/*! + * \brief Function parameter declaration. + */ +class Param; +/*! \brief A parameter. */ +class ParamNode : public ExprNode { + public: + /*! \brief The variable */ + LocalVar var; + /*! \brief The type of the parameter */ + Type type; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("type", &type); + v->Visit("span", &span); + } + + TVM_DLL static Param make(LocalVar var, Type type); + + static constexpr const char* _type_key = "relay.Param"; + TVM_DECLARE_NODE_TYPE_INFO(ParamNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Param, ParamNode, Expr); + +/*! + * \brief Function (subgraph in computational graph) + */ +class Function; +/*! \brief Function container */ +class FunctionNode : public ExprNode { + public: + /*! \brief Function parameters */ + tvm::Array params; + /*! \brief User annotated return type of the function. */ + Type ret_type; + /*! + * \brief + * The expression which represents the computation of the function, + * the expression may reference the parameters, and the type of it + * or sub-expressions may reference the type variables. + */ + Expr body; + /*! + * \brief Type parameters of the function. + * Enables the function to vary its type based on these. + * This corresponds to template paramaters in c++'s terminology. + * + * \note This can be usually empty for non-polymorphic functions. + */ + tvm::Array type_params; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("params", ¶ms); + v->Visit("ret_type", &ret_type); + v->Visit("body", &body); + v->Visit("type_params", &type_params); + v->Visit("span", &span); + } + + TVM_DLL static Function make(tvm::Array params, Type ret_type, + Expr body, tvm::Array ty_params); + + static constexpr const char* _type_key = "relay.Function"; + TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); + +// TODO(tqchen) change Expr to Attr after we introduce Attr system. +using Attrs = tvm::Map; + +/*! + * \brief Call corresponds to operator invocation. + * Corresponds to the operator in computational graph terminology. + */ +class Call; +/*! \brief Call container. */ +class CallNode : public ExprNode { + public: + /*! + * \brief The operator(function) being invoked + * + * - It can be relay::Op which corresponds to the primitive operators. + * - It can also be user defined functions (Function, GlobalVar, LocalVar). + */ + Expr op; + + /*! \brief The arguments(inputs) of the call */ + tvm::Array args; + + /*! \brief The additional attributes */ + Attrs attrs; + + /*! + * \brief The type arguments passed to polymorphic(template) function. + * + * This is the advance feature that is only used when the function is + * polymorphic. It is safe to be ignored in most cases. For example, in the + * following code, the type_args of addone call is [int]. + * + * \code + * + * template + * T addone(T a) { return a + 1; } + * + * void main() { + * int x = addone(10); + * } + * + * \endcode + */ + tvm::Array type_args; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("op", &op); + v->Visit("args", &args); + v->Visit("type_args", &type_args); + v->Visit("span", &span); + } + + TVM_DLL static Call make(Expr op, Array args, Attrs attrs, + Array ty_args); + + static constexpr const char* _type_key = "relay.Call"; + TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Call, CallNode, Expr); + +/*! + * \brief Let binding that binds a local var and optionally a type annotation. + * + * \note Let is useful to transform the program to be A-normal form. + * where each of the expression corresponds to a let binding. + * + * For developers who are familar with the computational graph. + * Each of the let can be viewed as a operator node in the computational graph. + * Traversing the list of let bindings is similar to running + * PostDFS-order(topo-order) traversal on the computational graph. + */ +class Let; +/*! \brief A binding of a sub-network. */ +class LetNode : public ExprNode { + public: + /*! \brief The variable we bind to */ + LocalVar var; + /*! \brief The value we bind var to */ + Expr value; + /*! \brief The body of the let binding */ + Expr body; + /*! \brief type annotation of value, this can be null */ + Type value_type; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("body", &body); + v->Visit("value_type", &value_type); + v->Visit("span", &span); + } + + TVM_DLL static Let make(LocalVar var, Expr value, Expr body, Type value_type); + + static constexpr const char* _type_key = "relay.Let"; + TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Let, LetNode, Expr); + +/*! + * \brief Condition expression + */ +class If; +/*! \brief container of If */ +class IfNode : public ExprNode { + public: + /*! \brief The condition */ + Expr cond; + /*! \brief The value to take when condition is true */ + Expr true_value; + /*! \brief The value to take when condition is false */ + Expr false_value; + + IfNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("cond", &cond); + v->Visit("true_value", &true_value); + v->Visit("false_value", &false_value); + v->Visit("span", &span); + } + + TVM_DLL static If make(Expr cond, Expr true_value, Expr false_value); + + static constexpr const char* _type_key = "relay.If"; + TVM_DECLARE_NODE_TYPE_INFO(IfNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(If, IfNode, Expr); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h new file mode 100644 index 000000000000..4c6995646114 --- /dev/null +++ b/include/tvm/relay/type.h @@ -0,0 +1,243 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/type.h + * \brief Relay typed AST nodes. + */ +#ifndef TVM_RELAY_TYPE_H_ +#define TVM_RELAY_TYPE_H_ + +#include +#include +#include +#include + +#include "./base.h" + +namespace tvm { +namespace relay { + +/*! \brief Base type of the Relay type hiearchy. */ +class TypeNode : public RelayNode { + public: + static constexpr const char* _type_key = "relay.Type"; + TVM_DECLARE_BASE_NODE_INFO(TypeNode, Node); +}; + +/*! + * \brief Type is the base type of relay type hiearchy. + * + * Relay's type system contains following two key concepts: + * + * - TensorType: type of certain Tensor values in the expression. + * - FunctionType: the type of the function. + * + * There are also advanced types to support generic(polymorphic types), + * which can be ignored when first reading the code base. + */ +class Type : public NodeRef { + public: + Type() {} + explicit Type(std::shared_ptr p) : NodeRef(p) {} + + using ContainerType = TypeNode; +}; + +/*! + * \brief Base of all Tensor types + * This container can hold TensorType or GenericTensorType. + */ +class BaseTensorTypeNode : public TypeNode { + public: + static constexpr const char* _type_key = "relay.BaseTensorType"; + TVM_DECLARE_BASE_NODE_INFO(BaseTensorTypeNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(BaseTensorType, BaseTensorTypeNode, Type); + +/*! + * \brief This is the most commonly used type in relay. + * TensorType have a fixed dimension, data type. + * + * The elements of shape can be either IntImm(constant integer), + * or any symbolic integer expression. + * The symbolic integer allows generic shape inference in certain cases. + * \sa TensorTypeNode The container class of TensorType. + */ +class TensorType; +/*! \brief TensorType container node */ +class TensorTypeNode : public BaseTensorTypeNode { + public: + /*! + * \brief The shape of the tensor, + * represented by ShapeExpr(tvm::Expr). + */ + Array shape; + /*! \brief The content data type */ + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + v->Visit("span", &span); + } + + TVM_DLL static TensorType make(Array shape, DataType dtype); + + static constexpr const char* _type_key = "relay.TensorType"; + TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode); +}; + +RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type); + +/*! + * \brief Type parameter in the function. + * This can be viewed as template parameter in c++ template function. + * + * For example, in the following pesudo code, + * the TypeParam of f is TypeParam(kind=kShapeVar, var=n). + * This function can take in a Tensor with shape=(3, 3) and + * returns a Tensor with shape=(9,) + * + * \code + * + * template + * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)] + * + * \endcode + * \sa TypeParamNode The actual container class of TypeParam + */ +class TypeParam; +/*! \brief TypeParam container node */ +class TypeParamNode : public TypeNode { + public: + /*! \brief possible kinds of TypeParam */ + enum Kind : int { + /*! \brief template variable in shape expression */ + kShapeVar = 0 + }; + /*! + * \brief The variable + * The variable itself is only meaningful when + * kind is ShapeVar, otherwise, we can only use the name. + */ + tvm::Var var; + /*! \brief The kind of type parameter */ + Kind kind; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("kind", &kind); + v->Visit("span", &span); + } + + TVM_DLL static TypeParam make(std::string name, Kind kind); + + static constexpr const char* _type_key = "relay.TypeParam"; + TVM_DECLARE_NODE_TYPE_INFO(TypeParamNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type); + +/*! + * \brief Potential Constraints in the type. + * \note This is reserved for future use. + */ +class TypeConstraint; +/*! \brief TypeConstraint container node. */ +class TypeConstraintNode : public Node { + public: + static constexpr const char* _type_key = "relay.TypeConstraint"; + TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, Node); +}; + +RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, NodeRef); + +class FuncType; +/*! + * \brief Function type in Relay. + * + * Relay support polymorphic function type. + * This can be roughly viewed as template function in C++. + * + * \sa TypeParam, TypeConstraint + */ +class FuncTypeNode : public TypeNode { + public: + /*! \brief type type of arguments */ + tvm::Array arg_types; + /*! \brief The type of return value. */ + Type ret_type; + // The following fields are used in polymorphic(template) functions + // For normal functions, the following two fields will be empty. + /*! \brief The type parameters of the function */ + tvm::Array type_params; + /*! + * \brief potential constraint the type need to obey + * \note this field is reserved for futher purposes. + */ + tvm::Array type_constraints; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("arg_types", &arg_types); + v->Visit("ret_type", &ret_type); + v->Visit("type_params", &type_params); + v->Visit("type_constraints", &type_constraints); + v->Visit("span", &span); + } + + TVM_DLL static FuncType make(tvm::Array arg_types, Type ret_type, + tvm::Array type_params, + tvm::Array type_constraints); + + static constexpr const char* _type_key = "relay.FuncType"; + TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); + +/*! + * \brief Opaque type inference function. + */ +class TypeFunction; +/*! + * \brief TypeFunction container. + * \note This node is not directly serializable. + * The type function need to be lookedup in the environment. + */ +class TypeFunctionNode : public RelayNode { + public: + /*! \brief The name of the function */ + std::string name; + /*! \brief Number of input type arguments, can be -1, which means VarArgs */ + int num_args; + /*! + * \brief The type function, + * this is not directly serializable, + * need to be looked-up in the environment. + */ + mutable std::function& arg_types)> func_; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("num_args", &num_args); + } + + TVM_DLL static TypeFunction make(std::string name, int num_args); + + static constexpr const char* _type_key = "relay.TypeFunction"; + TVM_DECLARE_NODE_TYPE_INFO(TypeFunctionNode, RelayNode); +}; + +RELAY_DEFINE_NODE_REF(TypeFunction, TypeFunctionNode, NodeRef); + +// The following fields contains advanced typing +// Only keep the class name and reserved for future usage. +class GenericTensorType; +// stores a DataType. +class GenericDataType; +// stores a DataType. +class GenericShape; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_TYPE_H_ diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py new file mode 100644 index 000000000000..c90875db4178 --- /dev/null +++ b/python/tvm/relay/__init__.py @@ -0,0 +1,12 @@ +"""Relay namespace.""" +from . import base +from . import type as tpe +from . import make + +# Type +Type = tpe.Type +TensorType = tpe.TensorType +Kind = tpe.Kind +TypeParam = tpe.TypeParam +TypeConstraint = tpe.TypeConstraint +FuncType = tpe.FuncType diff --git a/python/tvm/relay/_make.py b/python/tvm/relay/_make.py new file mode 100644 index 000000000000..20a582e76d6a --- /dev/null +++ b/python/tvm/relay/_make.py @@ -0,0 +1,9 @@ +""" +The constructors for all Relay AST nodes exposed from C++. + +This module includes MyPy type signatures for all of the +exposed modules. +""" +from .._ffi.function import _init_api + +_init_api("relay._make", __name__) diff --git a/python/tvm/relay/_make.pyi b/python/tvm/relay/_make.pyi new file mode 100644 index 000000000000..d94857916319 --- /dev/null +++ b/python/tvm/relay/_make.pyi @@ -0,0 +1,91 @@ +# from typing import Dict, List, Any, Callable, TypeVar as PyTypeVar +# import nnvm.relay.ir as ir +# import nnvm.relay.env as env +# import ctypes + +# # Environment +# def Environment(items: Dict[ir.GlobalId, ir.Item]) -> env.Environment: ... + +# # Items TODO(@jroesch) Correct Anys to the right type. +# def Operator(id: ir.OperatorId, tvm_name: str, ty: ir.Type, compiler: Any, fwd_mode: Any, rev_mode: Any) -> ir.Operator: ... +# def Defn(id: ir.GlobalId, ty: ir.Type, body: ir.Function) -> ir.Defn: ... + +# # Types +# def IntType(bits: int, lanes: int) -> ir.Type: ... +# def UIntType(bits: int, lanes: int) -> ir.Type: ... +# def FloatType(bits: int, lanes: int) -> ir.Type: ... +# def BoolType(lanes: int) -> ir.Type: ... +# def TupleType(fields: List[ir.Type]) -> ir.Type: ... +# def TensorType(dtype: ir.Type, shape: ir.Type) -> ir.Type: ... +# def TypeParam(name: str, kind: ir.Kind) -> ir.Type: ... +# def TypeQuantifier(id: ir.TypeId, body: ir.Type) -> ir.Type: ... +# def TypeArrow(left: ir.Type, right: ir.Type) -> ir.Type: ... +# def TypeVar(kind: ir.Kind) -> ir.Type: ... +# def PlaceholderType() -> ir.Type: ... +# def ShapeSeq(shapes: List[ir.Type]) -> ir.ShapeSeq: ... +# def ShapeSingleton(value: int) -> ir.ShapeSingleton: ... +# def ShapeAttr(id: ir.StringLit) -> ir.ShapeAttr: ... +# def ShapeProjection(shape: ir.Type, value: int) -> ir.ShapeProjection: ... +# def ShapeBinaryOp(op: ir.ShapeOp, left: ir.Type, right: ir.Type) -> ir.ShapeBinaryOp: ... +# def ShapeBroadcast(left: ir.Type, right: ir.Type) -> ir.ShapeBroadcast: ... +# def ShapeExtension(name: str, eval: Any) -> ir.ShapeExtension: ... +# def TypeCall(func: ir.Type, args: List[ir.Type]) -> ir.TypeCall: ... +# def RefType(data_type: ir.Type) -> ir.RefType: ... + +# # Expressions +# def Param(id: ir.LocalId, type: ir.Type) -> ir.Param: ... +# def Function(ty_params: List[ir.TypeId], params: List[ir.Param], ret_type: ir.Type, body: ir.Expr) -> ir.Function: ... +# def LocalId(name: str) -> ir.Expr: ... +# def GlobalId(name: str) -> ir.Expr: ... +# def OperatorId(name: str) -> ir.Expr: ... +# def Let(id: ir.LocalId, ty: ir.Type, value: ir.Expr, body: ir.Expr) -> ir.Expr: ... +# def IntLit(value: int) -> ir.IntLit: ... +# def FloatLit(value: float) -> ir.FloatLit: ... +# def TensorLit(value: List[ir.Expr]) -> ir.TensorLit: ... +# def Tuple(fields: List[ir.Expr]) -> ir.Expr: ... +# def BoolLit(value: bool) -> ir.BoolLit: ... +# def StringLit(value: str) -> ir.StringLit: ... +# def Attributes(attrs: Dict[str, ir.Expr]) -> ir.Attributes: ... +# def Call(func: ir.Expr, args: List[ir.Expr], attrs: ir.Attributes) -> ir.Call: ... +# def UnaryOp(op: ir.UOp, arg: ir.Expr) -> ir.Expr: ... +# def BinaryOp(op: ir.BOp, left: ir.Expr, right: ir.Expr) -> ir.Expr: ... +# def Projection(tuple: ir.Expr, field : int) -> ir.Expr: ... +# def Gradient(node: ir.Expr) -> ir.Expr: ... +# def Cast(target: ir.Type, node: ir.Expr) -> ir.Expr: ... +# def Debug(node: ir.Expr) -> ir.Expr: ... +# def Zero(type: ir.Type) -> ir.Expr: ... +# def If(guard: ir.Expr, true_branch: ir.Expr, false_branch: ir.Expr) -> ir.Expr: ... +# def Ref(value: ir.Expr) -> ir.Expr: ... +# def ReadRef(ref: ir.Expr) -> ir.Expr: ... +# def WriteRef(ref: ir.Expr, value: ir.Expr) -> ir.Expr: ... + +# # Values +# def IntValue(value: int) -> ir.TensorValue: ... +# def FloatValue(value: float) -> ir.TensorValue: ... +# def BoolValue(value: bool) -> ir.TensorValue: ... +# def TensorValue(handle: ctypes.c_void_p) -> ir.TensorValue: ... +# def Closure(env: Dict[ir.LocalId, ir.Value], fn: ir.Function) -> ir.Closure: ... + +# # Error Reporting +# def Span(file_id: ir.FileId, lineno: int, col_offset: int) -> ir.NodeBase: ... +# def FileId(file_id: int) -> ir.FileId: ... + +# # Utils +# def _alpha_eq(e1: ir.Expr, e2: ir.Expr) -> bool: ... +# def _type_alpha_eq(e1: ir.Type, e2: ir.Type) -> bool: ... +# def _expr_set_span(e: ir.Expr, sp: ir.Span) -> None: ... +# def _type_set_span(t: ir.Type, sp: ir.Span) -> None: ... +# def _item_set_span(t: ir.Item, sp: ir.Span) -> None: ... +# def Node_hash(n: ir.Node) -> int: ... +# def Operator_is_generic(op: ir.Operator) -> bool: ... + +# # FIXME +# def UnionFind() -> Any: ... +# def TypeUnifier() -> Any: ... + +# T = PyTypeVar('T') +# U = PyTypeVar('U') +# PassFunc = Callable[[env.Environment], Callable[[T], U]] + +# # Passes +# def ItemPass(name: str, pass_func: PassFunc[ir.Item, ir.Item]) -> ir.ItemPass: ... diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py new file mode 100644 index 000000000000..687ba53ac005 --- /dev/null +++ b/python/tvm/relay/base.py @@ -0,0 +1,27 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck +"""The base node types for the Relay language.""" +from __future__ import absolute_import as _abs +from typing import Union +from .._ffi.node import NodeBase, register_node as _register_tvm_node + +NodeBase = NodeBase + +def register_relay_node(type_key=None): + """register relay node type + + Parameters + ---------- + type_key : str or cls + The type key of the node + """ + if not isinstance(type_key, str): + return _register_tvm_node( + "relay." + type_key.__name__)(type_key) + return _register_tvm_node(type_key) + + +@register_relay_node +class Span(NodeBase): + source: "FileSource" + lineno: int + col_offset: int diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py new file mode 100644 index 000000000000..dea3a99f5f09 --- /dev/null +++ b/python/tvm/relay/expr.py @@ -0,0 +1,69 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""The expression nodes of Relay.""" +import tvm +from typing import Tuple as PyTuple, List +from enum import IntEnum +from .base import Span, NodeBase, register_relay_node +from .type import Type, TypeParam +from tvm import expr + +class Expr(NodeBase): + """The base type for all Relay exprressions.""" + pass + +@register_relay_node +class Constant(Expr): + """A constant tensor in Relay, see tvm/relay/type.h for more details. + """ + data: tvm.nd.NDArray + +@register_relay_node +class Tuple(Expr): + """A hetereogenous sequence of values. + see tvm/relay/type.h for more details. + """ + fields: List[Expr] + +@register_relay_node +class LocalVar(Expr): + """A local variable in Relay.""" + name_hint: str + +@register_relay_node +class GlobalVar(Expr): + """A global variable in Relay.""" + name_hint: str + +@register_relay_node +class Param(Expr): + """A function type in Relay, see tvm/relay/type.h for more details. + """ + var: LocalVar + type: Type + +@register_relay_node +class Function(Expr): + type_params: List[TypeParam] + params: List[Param] + ret_type: Type + body: Expr + +class Call(Expr): + op: Expr + args: List[Expr] + # todo(@jroesch): add attrs + +@register_relay_node +class Let(Expr): + var: LocalVar + value: Expr + body: Expr + value_type: Type # should be type nanotation + +@register_relay_node +class If(Expr): + cond: Expr + true_value: Expr + false_value: Expr + span: Span + diff --git a/python/tvm/relay/make.py b/python/tvm/relay/make.py new file mode 100644 index 000000000000..14d9ac040dc9 --- /dev/null +++ b/python/tvm/relay/make.py @@ -0,0 +1,20 @@ +from . import _make + +# Base Constructors +Span = _make.Span + +# Type Constructors +TensorType = _make.TensorType +TypeParam = _make.TypeParam +FuncType = _make.FuncType + +# Expr Constructors +Constant = _make.Constant +Tuple = _make.Tuple +LocalVar = _make.LocalVar +GlobalVar = _make.GlobalVar +Param = _make.Param +Function = _make.Function +Call = _make.Call +Let = _make.Let +If = _make.If diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py new file mode 100644 index 000000000000..c92f0d756587 --- /dev/null +++ b/python/tvm/relay/type.py @@ -0,0 +1,51 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""The type nodes of the Relay language.""" +from typing import Tuple, List +from enum import IntEnum +from .base import Span, NodeBase, register_relay_node +from tvm import expr + +class Type(NodeBase): + """The base type for all Relay types.""" + pass + +@register_relay_node +class TensorType(Type): + """A concrete TensorType in Relay, see tvm/relay/type.h for more details. + """ + dtype: str + shape: List[expr.Expr] + span: Span + +class Kind(IntEnum): + """The kind of a type parameter, represents a variable shape, + base type, type, or dimension. + """ + Shape = 0 + BaseType = 1 + Type = 2 + Elem = 3 + +@register_relay_node +class TypeParam(Type): + """A type parameter used for generic types in Relay, + see tvm/relay/type.h for more details. + """ + var: expr.Var + kind: Kind + span: Span + +@register_relay_node +class TypeConstraint(Type): + """Abstract class representing a type constraint.""" + pass + +@register_relay_node +class FuncType(Type): + """A function type in Relay, see tvm/relay/type.h for more details. + """ + type_params: List[TypeParam] + type_constraints: List[TypeConstraint] + arg_types: List[Type] + ret_type: Type + span: Span diff --git a/src/relay/base.cc b/src/relay/base.cc new file mode 100644 index 000000000000..5fdf96ded224 --- /dev/null +++ b/src/relay/base.cc @@ -0,0 +1,40 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file base.cc + * \brief The core base types for Relay. + */ +#include +#include + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +Span SpanNode::make(SourceName source, int lineno, int col_offset) { + std::shared_ptr n = std::make_shared(); + n->source = std::move(source); + n->lineno = lineno; + n->col_offset = col_offset; + return Span(n); +} + +TVM_REGISTER_API("relay._make.Span") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = SpanNode::make(args[0], args[1], args[2]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { + p->stream << node->name; + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const SpanNode *node, tvm::IRPrinter *p) { + p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", " + << node->col_offset << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/expr.cc b/src/relay/expr.cc new file mode 100644 index 000000000000..38df81940e48 --- /dev/null +++ b/src/relay/expr.cc @@ -0,0 +1,181 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file expr.cc + * \brief The expression AST nodes of Relay. + */ +#include "tvm/relay/expr.h" +#include "tvm/ir_functor.h" + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +Constant ConstantNode::make(runtime::NDArray data) { + std::shared_ptr n = std::make_shared(); + n->data = std::move(data); + return Constant(n); +} + +TVM_REGISTER_API("relay._make.Constant") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ConstantNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const ConstantNode *node, + tvm::IRPrinter *p) { + p->stream << "ConstantNode(TODO)"; + }); + +Tuple TupleNode::make(tvm::Array fields) { + std::shared_ptr n = std::make_shared(); + n->fields = std::move(fields); + return Tuple(n); +} + +TVM_REGISTER_API("relay._make.Tuple") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TupleNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TupleNode *node, tvm::IRPrinter *p) { + p->stream << "TupleNode(" << node->fields << ")"; + }); + +LocalVar LocalVarNode::make(std::string name_hint) { + std::shared_ptr n = std::make_shared(); + n->name_hint = std::move(name_hint); + return LocalVar(n); +} + +TVM_REGISTER_API("relay._make.LocalVar") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LocalVarNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const LocalVarNode *node, + tvm::IRPrinter *p) { + p->stream << "LocalVarNode(" << node->name_hint << ")"; + }); + +GlobalVar GlobalVarNode::make(std::string name_hint) { + std::shared_ptr n = std::make_shared(); + n->name_hint = std::move(name_hint); + return GlobalVar(n); +} + +TVM_REGISTER_API("relay._make.GlobalVar") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = GlobalVarNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const GlobalVarNode *node, + tvm::IRPrinter *p) { + p->stream << "GlobalVarNode(" << node->name_hint << ")"; + }); + +Param ParamNode::make(LocalVar var, Type type) { + std::shared_ptr n = std::make_shared(); + n->var = std::move(var); + n->type = std::move(type); + return Param(n); +} + +TVM_REGISTER_API("relay._make.Param") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ParamNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const ParamNode *node, tvm::IRPrinter *p) { + p->stream << "ParamNode(" << node->var << ", " << node->type << ")"; + }); + +Function FunctionNode::make(tvm::Array params, Type ret_type, Expr body, + tvm::Array type_params) { + std::shared_ptr n = std::make_shared(); + n->params = std::move(params); + n->ret_type = std::move(ret_type); + n->body = std::move(body); + n->type_params = std::move(type_params); + return Function(n); +} + +TVM_REGISTER_API("relay._make.Function") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FunctionNode::make(args[0], args[1], args[2], args[3]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const FunctionNode *node, + tvm::IRPrinter *p) { + p->stream << "FunctionNode(TODO)"; + }); + +Call CallNode::make(Expr op, Array args, Attrs attrs, + Array type_args) { + std::shared_ptr n = std::make_shared(); + n->op = std::move(op); + n->args = std::move(args); + n->attrs = std::move(attrs); + n->type_args = std::move(type_args); + return Call(n); +} + +TVM_REGISTER_API("relay._make.Call") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = CallNode::make(args[0], args[1], args[2], args[3]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const CallNode *node, tvm::IRPrinter *p) { + p->stream << "CallNode(" << node->op << ", " << node->args << ", " + << node->attrs << ", " << node->type_args << ")"; + }); + +Let LetNode::make(LocalVar var, Expr value, Expr body, Type value_type) { + std::shared_ptr n = std::make_shared(); + n->var = std::move(var); + n->value = std::move(value); + n->body = std::move(body); + n->value_type = std::move(value_type); + return Let(n); +} + +TVM_REGISTER_API("relay._make.Let") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LetNode::make(args[0], args[1], args[2], args[3]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const LetNode *node, tvm::IRPrinter *p) { + p->stream << "LetNode(" << node->var << node->value << node->body << node->value_type << ")"; + }); + +If IfNode::make(Expr cond, Expr true_value, Expr false_value) { + std::shared_ptr n = std::make_shared(); + n->cond = std::move(cond); + n->true_value = std::move(true_value); + n->false_value = std::move(false_value); + return If(n); +} + +TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = IfNode::make(args[0], args[1], args[2]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const IfNode *node, tvm::IRPrinter *p) { + p->stream << "IfNode(" << + node->cond << ", " << + node->true_value << + node->false_value << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/type.cc b/src/relay/type.cc new file mode 100644 index 000000000000..156207e1b73a --- /dev/null +++ b/src/relay/type.cc @@ -0,0 +1,100 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type.cc + * \brief The type system AST nodes of Relay. + */ +#include "tvm/relay/type.h" +#include "tvm/ir_functor.h" + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +TensorType TensorTypeNode::make(Array shape, DataType dtype) { + std::shared_ptr n = std::make_shared(); + n->shape = std::move(shape); + n->dtype = std::move(dtype); + return TensorType(n); +} + +TVM_REGISTER_API("relay._make.TensorType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Array shape = args[0]; + *ret = TensorTypeNode::make(shape, args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TensorTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "TensorTypeNode(" << node->dtype << ", " << node->shape + << ")"; + }); + +TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) { + std::shared_ptr n = std::make_shared(); + n->var = tvm::Var(name); + n->kind = std::move(kind); + return TypeParam(n); +} + +TVM_REGISTER_API("relay._make.TypeParam") + .set_body([](TVMArgs args, TVMRetValue *ret) { + int kind = args[1]; + *ret = + TypeParamNode::make(args[0], static_cast(kind)); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TypeParamNode *node, + tvm::IRPrinter *p) { + p->stream << "TypeParamNode(" << node->var->name_hint << ", " + << node->kind << ")"; + }); + + +FuncType FuncTypeNode::make(tvm::Array arg_types, Type ret_type, + tvm::Array type_params, + tvm::Array type_constraints) { + std::shared_ptr n = std::make_shared(); + n->arg_types = std::move(arg_types); + n->ret_type = std::move(ret_type); + n->type_params = std::move(type_params); + n->type_constraints = std::move(type_constraints); + return FuncType(n); +} + +TVM_REGISTER_API("relay._make.FuncType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const FuncTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "FuncTypeNode(" << node->type_params << ", " + << node->arg_types << ", " << node->ret_type << ", " + << node->type_constraints << ")"; + }); + +TypeFunction TypeFunctionNode::make(std::string name, int num_args) { + std::shared_ptr n = std::make_shared(); + n->name = std::move(name); + n->num_args = std::move(num_args); + return TypeFunction(n); +} + +TVM_REGISTER_API("relay._make.TypeFunction") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TypeFunctionNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TypeFunctionNode *node, + tvm::IRPrinter *p) { + p->stream << "TypeFunctionNode(" << node->name << ", " << node->num_args << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py new file mode 100644 index 000000000000..26fe06109513 --- /dev/null +++ b/tests/python/relay/test_ir_nodes.py @@ -0,0 +1,154 @@ +""" test ir""" +import tvm +from tvm import relay +import tvm.relay.make as mk +from tvm import expr + +# Span + + +def test_span() -> None: + span = mk.Span(None, 1, 1) + assert span.source == None + assert span.lineno == 1 + assert span.col_offset == 1 + assert span.same_as(span) + assert span == span + assert isinstance(span, relay.base.Span) + str(span) + +# Types + + +def test_tensor_type() -> None: + shape = tvm.convert([1, 2, 3]) + dtype = 'float32' + tt = mk.TensorType(shape, dtype) + assert tt.dtype == dtype + assert tt.shape == shape + assert tt.span == None + str(tt) + + +def test_type_param() -> None: + tp = mk.TypeParam('name', relay.Kind.Shape) + tp.kind == relay.Kind.Shape + tp.span # TODO allow us to set span + str(tp) + + +def test_func_type() -> None: + type_params = tvm.convert([]) + type_constraints = tvm.convert([]) # TODO: fill me in + arg_types = tvm.convert([]) + ret_type = None + tf = mk.FuncType(arg_types, ret_type, type_params, type_constraints) + assert tf.type_params == type_params + assert tf.type_constraints == type_constraints + assert tf.arg_types == arg_types + assert tf.ret_type == ret_type + assert tf.span == None + # TODO make sure we can set + str(tf) + + +def test_constant() -> None: + arr = tvm.nd.array(10) + const = mk.Constant(arr) + assert const.data == arr + assert const.span == None + str(const) + + +def test_tuple() -> None: + fields = tvm.convert([]) + tup = mk.Tuple(fields) + assert tup.fields == fields + assert tup.span == None + str(tup) + + +def test_local_var() -> None: + name_hint = 's' + lv = mk.LocalVar(name_hint) + lv.name_hint == name_hint + # assert lv.span == None todo(@jroesch): what do we do about spans + str(lv) + + +def test_global_var() -> None: + name_hint = 'g' + gv = mk.GlobalVar(name_hint) + gv.name_hint == name_hint + # assert lv.span == None todo(@jroesch): what do we do about spans + str(gv) + + +def test_param() -> None: + lv = mk.LocalVar('x') + ty = None + param = mk.Param(lv, ty) + assert param.var == lv + assert param.type == ty + assert param.span == None + str(param) + + +def test_function() -> None: + param_names = ['a', 'b', 'c', 'd'] + params = tvm.convert([mk.Param(mk.LocalVar(n), None) for n in param_names]) + ret_type = None + body = None + type_params = tvm.convert([]) + fn = mk.Function(params, ret_type, body, type_params) + assert fn.params == params + assert fn.body == body + assert fn.type_params == type_params + assert fn.span == None + str(fn) + + +def test_call() -> None: + op = mk.LocalVar('f') + arg_names = ['a', 'b', 'c', 'd'] + args = tvm.convert([mk.LocalVar(n) for n in arg_names]) + call = mk.Call(op, args, None, None) + assert call.op == op + assert call.args == args + assert call.span == None + str(call) + + +def test_let() -> None: + lv = mk.LocalVar('x') + ty = None + arr = tvm.nd.array(10) + value = mk.Constant(arr) + # I would prefer that the order of arguments + # matches syntax let x : t = v in b + let = mk.Let(lv, value, lv, ty) + assert let.var == lv + assert let.value == value + assert let.value_type == ty + assert let.body == lv + assert let.span == None + str(let) + + +def test_if() -> None: + cond = mk.LocalVar('cond') + left = mk.LocalVar('left') + right = mk.LocalVar('right') + ife = mk.If(cond, left, right) + assert ife.cond == cond + assert ife.true_value == left + assert ife.false_value == right + assert ife.span == None + str(ife) + + +if __name__ == "__main__": + test_span() + test_tensor_type() + test_type_param() + test_func_type() From ac1455d0ed72a6c3667604bb830572c37028124a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 16:31:02 -0700 Subject: [PATCH 02/88] Add InternTable data structure --- include/tvm/relay/compiler/intern_table.h | 55 +++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 include/tvm/relay/compiler/intern_table.h diff --git a/include/tvm/relay/compiler/intern_table.h b/include/tvm/relay/compiler/intern_table.h new file mode 100644 index 000000000000..1850e513e5e5 --- /dev/null +++ b/include/tvm/relay/compiler/intern_table.h @@ -0,0 +1,55 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/compiler/intern_table.h + * \brief A table which maps string keys to data. + * + * These are useful for mapping user-readable names + * to globally unique allocations which use pointer + * equality for comparsion. + */ +#ifndef TVM_RELAY_COMPILER_INTERN_TABLE_H_ +#define TVM_RELAY_COMPILER_INTERN_TABLE_H_ + +#include +#include +#include "dmlc/logging.h" + +namespace tvm { +namespace relay { + +struct KeyNotFound : dmlc::Error { + explicit KeyNotFound(std::string msg) : dmlc::Error(msg) {} +}; + +template +class InternTable { +private: + /*! \brief The internal table mapping from strings to T. */ + std::unordered_map table_; + + public: + /*! \brief Insert a new key into the table. + * \note Attempting to reinsert a key triggers an error. + */ + void Insert(const std::string& key, const T& value) { + if (table_.find(key) == table_.end()) { + table_.insert({key, value}); + } else { + throw dmlc::Error( + std::string("you have previously interred a value for: ") + key); + } + } + + /*! \brief Lookup the data in the table. */ + const T& Lookup(std::string key) { + if (table_.find(key) != table_.end()) { + return table_.at(key); + } else { + throw KeyNotFound(std::string("could not find match") + key); + } + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_COMPILER_INTERN_TABLE_H_ From 4a285f8b12fb13fc22096d1679d6b62c584e4842 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 17:56:59 -0700 Subject: [PATCH 03/88] Add placeholder defn of Operator --- include/tvm/relay/expr.h | 2 +- include/tvm/relay/op.h | 47 ++++++++++++++++++++++++++++++++++++++++ src/relay/op.cc | 31 ++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 include/tvm/relay/op.h create mode 100644 src/relay/op.cc diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index b830c7ce04ef..c1dd557717af 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2018 by Contributors * \file tvm/relay/expr.h - * \brief Relay expression IR Node. + * \brief The Relay IR expression nodes. */ #ifndef TVM_RELAY_EXPR_H_ #define TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h new file mode 100644 index 000000000000..fa152945d38c --- /dev/null +++ b/include/tvm/relay/op.h @@ -0,0 +1,47 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/op.h + * \brief Relay's representation of operators. + */ +#ifndef TVM_RELAY_OP_H_ +#define TVM_RELAY_OP_H_ + +#include "./expr.h" + +namespace tvm { +namespace relay { + + +/*! + * \brief A primitive Relay operator defined externally to Relay. + * + * \note Currently these are expected to be backed by a TVM's operator, + * such as the ones defined in TOPI. + * + * For developers who are familar with the computational graph this + * directly maps to the concept of operators in NNVM. + */ +class Operator; +/*! \brief Container for Operator */ +class OperatorNode : public ExprNode { + public: + /*! \brief A type which specifies the relationship between the inputs and outputs + * of the operator. + */ + Type op_type; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("op_type", &op_type); + } + + TVM_DLL static Operator make(Type op_type); + + static constexpr const char* _type_key = "relay.Operator"; + TVM_DECLARE_NODE_TYPE_INFO(OperatorNode, OperatorNode); +}; + +RELAY_DEFINE_NODE_REF(Operator, OperatorNode, Expr); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_EXPR_H_ diff --git a/src/relay/op.cc b/src/relay/op.cc new file mode 100644 index 000000000000..07ad5f0ae4ed --- /dev/null +++ b/src/relay/op.cc @@ -0,0 +1,31 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file op.cc + * \brief Relay's representation of operators. + */ +#include "tvm/relay/op.h" +#include "tvm/ir_functor.h" + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace runtime; + +Operator OperatorNode::make(Type op_type) { + std::shared_ptr n = std::make_shared(); + n->op_type = std::move(op_type); + return Operator(n); +} + +TVM_REGISTER_API("relay._make.Operator").set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = OperatorNode::make(args[0]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const OperatorNode *node, tvm::IRPrinter *p) { + p->stream << "OperatorNode(" << node->op_type << ")"; + }); + +} // namespace relay +} // namespace tvm From 694e95b552601e367c9b9f60c2f4048d97295f37 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 22:37:16 -0700 Subject: [PATCH 04/88] Add initial port of environment.h --- include/tvm/relay/compiler/environment.h | 110 +++++++++ include/tvm/relay/error.h | 28 +++ src/relay/compiler/environment.cc | 292 +++++++++++++++++++++++ 3 files changed, 430 insertions(+) create mode 100644 include/tvm/relay/compiler/environment.h create mode 100644 include/tvm/relay/error.h create mode 100644 src/relay/compiler/environment.cc diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/compiler/environment.h new file mode 100644 index 000000000000..ddb7f0dca192 --- /dev/null +++ b/include/tvm/relay/compiler/environment.h @@ -0,0 +1,110 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file environment.h + * \brief The global environment containing + */ +#ifndef TVM_RELAY_ENVIRONMENT_H_ +#define TVM_RELAY_ENVIRONMENT_H_ + +#include +#include +#include "tvm/relay/compiler/intern_table.h" +#include "../expr.h" +#include "../type.h" +#include "../op.h" +#include "../error.h" +// #include "tvm/relay/options.h" +// #include "tvm/relay/source_map.h" + +namespace tvm { +namespace relay { + +struct Environment; + +/*! \brief The global environment of Relay programs. + * + * The global environment contains all the global + * information needed to compile a Relay program, + * including the set of operators, the set of + * global functions, and configuration options. + * + * Many operations require acess to the global + * Environment. We mostly pass the argument by value + * in a functional style as an explicit argument. + * + * This means users can construct custom environments + * easily, for example a fresh environment for each + * thread while auto-tuning. + * */ + +class EnvironmentNode : public RelayNode { + private: + /*! A map from string names to GlobalIds, ensures global uniqueness. */ + InternTable global_map_; + /*! A map from string names to Operators, ensures global uniqueness. */ + InternTable operator_map_; + // /*! \brief A map from file names to source fragments. */ + // SourceMap source_map_; + // /*! \brief A list of the errors reported during the current run. */ + // std::vector errors_; + + public: + // This map contains all items *except* operators. + std::unordered_map items; + + // Options options; + + tvm::PackedFunc jit_for(Operator op); + tvm::PackedFunc reverse(Operator op); + + EnvironmentNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final {} + + TVM_DLL static Environment make( + std::unordered_map global_funcs); + + // Add an item to the Enviroment. + // void add(const Operator& op, bool update = false); + // void add(const Operator& op, bool update = false); + + // void try_add(const Item& item, bool update=false); + // void update(const Item& item); + // void remove(const GlobalId& id); + + // GlobalId global_id(const std::string& str); + // OperatorId operator_id(const std::string& str); + + // We can lookup a GlobalId, OperatorId. + // Defn lookup(const GlobalId& id); + // Operator lookup(const OperatorId& id); + // Defn lookup_global(const std::string& str); + // Item lookup_operator(const std::string& str); + // FileId add_source(std::string file_name, std::string source); + + // tvm::Array get_operators(); + // tvm::Array get_defns(); + + // void report_error(std::string msg, Span sp); + // void display_errors(); + // void register_shape_ext(ShapeExtension ext); + + static constexpr const char* _type_key = "relay.Environment"; + TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); +}; + +struct Environment : public NodeRef { + Environment() {} + explicit Environment(std::shared_ptr p) : NodeRef(p) {} + + inline EnvironmentNode* operator->() const { + return static_cast(node_.get()); + } + + using ContainerType = EnvironmentNode; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ENVIRONMENT_H_ diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h new file mode 100644 index 000000000000..d2698f8e380b --- /dev/null +++ b/include/tvm/relay/error.h @@ -0,0 +1,28 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file error.h + * \brief The set of errors raised by Relay. + */ +#ifndef TVM_RELAY_ERROR_H_ +#define TVM_RELAY_ERROR_H_ + +#include +#include "./base.h" + +namespace tvm { +namespace relay { + +struct Error : dmlc::Error { + Error(std::string msg) : dmlc::Error(msg) {} +}; + +struct SpannedError { + std::string msg; + Span sp; + SpannedError(std::string msg, Span sp) : msg(msg), sp(sp) {} +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ERROR_H_ diff --git a/src/relay/compiler/environment.cc b/src/relay/compiler/environment.cc new file mode 100644 index 000000000000..125ceae834b3 --- /dev/null +++ b/src/relay/compiler/environment.cc @@ -0,0 +1,292 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file environment.cc + * \brief Relay global environment. + */ +#include +#include "tvm/relay/compiler/environment.h" +// #include "tvm/relay/alpha_eq.h" +// #include "tvm/relay/debug.h" +// #include "tvm/relay/typeck/typechecker.h" +// #include "tvm/relay/util/rang.h" +// #include "tvm/runtime/packed_func_ext.h" + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +Environment EnvironmentNode::make( + std::unordered_map global_funcs) { + std::shared_ptr n = std::make_shared(); + n->items = std::move(global_funcs); + return Environment(n); +} + +// tvm::PackedFunc EnvironmentNode::jit_for(OperatorId id) { +// return this->lookup(id)->compiler; +// } + +// GlobalId EnvironmentNode::global_id(const std::string &str) { +// try { +// return global_map_.Lookup(str); +// } catch (const KeyNotFound &err) { +// GlobalId id = GlobalIdNode::make(str); +// global_map_.Insert(str, id); +// return id; +// } +// } + +// OperatorId EnvironmentNode::operator_id(const std::string &str) { +// try { +// return operator_map_.Lookup(str); +// } catch (const KeyNotFound &err) { +// OperatorId id = OperatorIdNode::make(str); +// operator_map_.Insert(str, id); +// return id; +// } +// } + +// // Add a new item to the global environment +// // throws an exception if the item already +// // exists. +// void EnvironmentNode::add(const Item &unchecked_item, bool update) { +// // Type check the item before we add it to the environment. +// auto env = GetRef(this); +// Item item = check(env, unchecked_item); + +// if (const OperatorNode *op_node = item.as()) { +// Operator op = GetRef(op_node); +// auto type = op->type; +// if (operators.find(op->id) != operators.end()) { +// if (!update) { +// throw dmlc::Error("already have definition for XXXX."); +// } + +// auto old_type = operators[op->id]->type; + +// if (!alpha_eq(type, old_type)) { +// throw dmlc::Error( +// "Environment#update changes type, not possible in this mode."); +// } + +// operators.insert({op->id, op}); +// } else { +// operators.insert({op->id, op}); +// } +// } else if (const DefnNode *d = item.as()) { +// auto def = GetRef(d); +// auto type = def->type; +// if (items.find(def->id) != items.end()) { +// if (!update) { +// throw dmlc::Error("already have definition for XXXX."); +// } + +// auto old_type = items[def->id].as()->type; + +// if (!alpha_eq(type, old_type)) { +// throw dmlc::Error( +// "Environment#update changes type, not possible in this mode."); +// } + +// this->items.insert({def->id, def}); +// } else { +// this->items.insert({def->id, def}); +// } +// } else { +// throw EnvError("internal error: unknown item type, unreachable code"); +// } +// } + +// void EnvironmentNode::update(const Item &item) { return this->add(item, true); } + +// void EnvironmentNode::remove(const GlobalId &id) { this->items.erase(id); } + +// Defn EnvironmentNode::lookup(const GlobalId &id) { +// if (items.find(id) != items.end()) { +// return items.at(id); +// } else { +// throw EnvError(std::string("there is no definition of ") + id->name); +// } +// } + +// Operator EnvironmentNode::lookup(const OperatorId &id) { +// if (operators.find(id) != operators.end()) { +// return operators.at(id); +// } else { +// throw EnvError(std::string("there is no definition of ") + id->name); +// } +// } + +// Item EnvironmentNode::lookup_operator(const std::string &str) { +// OperatorId id = this->operator_id(str); +// return lookup(id); +// } + +// Defn EnvironmentNode::lookup_global(const std::string &str) { +// GlobalId id = this->global_id(str); +// return this->lookup(id); +// } + +// inline FileId EnvironmentNode::add_source(std::string file_name, +// std::string source) { +// return this->source_map_.add_source(file_name, source); +// } + +// void EnvironmentNode::report_error(std::string msg, Span sp) { +// this->errors_.push_back(Error(msg, sp)); +// } + +// void EnvironmentNode::display_errors() { +// for (auto err : this->errors_) { +// auto sp = err.sp; +// auto source_file = this->source_map_.GetSource(err.sp->file_id); +// auto file_name = source_file.file_name; +// auto source_at_span = source_file.SourceAt(err.sp, 1); +// std::string error_marker = "error:"; +// auto line_info = +// std::to_string(sp->lineno) + ":" + std::to_string(sp->col_offset); + +// std::cout << rang::style::bold << rang::fg::red << error_marker +// << rang::fg::reset << file_name << ":" << line_info +// << rang::style::reset << " " << source_at_span << std::endl; + +// // Build the cursor. + +// // Fix this code, hardwired to compute alignment of pointer. +// size_t spaces = error_marker.size() + line_info.size() + file_name.size() + +// sp->col_offset - 3; + +// std::string cursor = "~~~~^~~~~"; +// for (size_t i = 0; i < spaces; i++) { +// std::cout << " "; +// } +// std::cout << rang::fg::red << cursor << " " << err.msg << rang::style::reset +// << std::endl; +// } +// } + +// Array EnvironmentNode::get_operators() { +// std::vector ops; +// for (auto pair : this->operators) { +// ops.push_back(pair.second); +// } +// return Array(ops); +// } + +// Array EnvironmentNode::get_defns() { +// std::vector defns; +// for (auto pair : this->items) { +// defns.push_back(pair.second); +// } +// return Array(defns); +// } + +// void EnvironmentNode::register_shape_ext(ShapeExtension ext) { +// this->shape_exts_.Insert(ext->name, ext); +// } + +// TVM_REGISTER_API("relay._make.Environment") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// *ret = EnvironmentNode::make({}); +// }); + +// TVM_REGISTER_API("relay._env.Environment_add") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// Item item = args[1]; +// env->add(item, true); // REMOVE ME +// }); + +// TVM_REGISTER_API("relay._env.Environment_lookup_global") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// GlobalId id = args[1]; +// *ret = env->lookup(id); +// }); + +// TVM_REGISTER_API("relay._env.Environment_lookup_operator") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// OperatorId id = args[1]; +// *ret = env->lookup(id); +// }); + +// // TVM_REGISTER_API("relay._env.Environment_remove_global") +// // .set_body([](TVMArgs args, TVMRetValue *ret) { +// // Environment env = args[0]; +// // GlobalId id = args[1]; +// // env->remove(id); +// // }); + +// TVM_REGISTER_API("relay._env.Environment_global_id") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// std::string str = args[1]; +// *ret = env->global_id(str); +// }); + +// TVM_REGISTER_API("relay._env.Environment_operator_id") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// std::string str = args[1]; +// *ret = env->operator_id(str); +// }); + +// TVM_REGISTER_API("relay._env.Environment_register_shape_ext") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// ShapeExtension ext = args[1]; +// env->register_shape_ext(ext); +// }); + +// TVM_REGISTER_API("relay._env.Environment_register_primitive") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// std::string str = args[1]; +// *ret = env->global_id(str); +// }); + +// TVM_REGISTER_API("relay._env.Environment_add_source") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// std::string file_name = args[1]; +// std::string source_name = args[2]; +// *ret = env->add_source(file_name, source_name); +// }); + +// TVM_REGISTER_API("relay._env.Environment_report_error") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// std::string msg = args[1]; +// Span sp = args[2]; +// env->report_error(msg, sp); +// }); + +// TVM_REGISTER_API("relay._env.Environment_display_errors") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// return env->display_errors(); +// }); + +// TVM_REGISTER_API("relay._env.Environment_get_operators") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// *ret = env->get_operators(); +// }); + +// TVM_REGISTER_API("relay._env.Environment_get_defns") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// *ret = env->get_defns(); +// }); + +// TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +// .set_dispatch([](const EnvironmentNode *node, +// tvm::IRPrinter *p) { +// p->stream << "EnvironmentNode(todo)"; // << node->items << ")"; +// }); + +} // namespace relay +} // namespace tvm From bd4544337392a7042e17ea928a742c5905f7d710 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 22:44:16 -0700 Subject: [PATCH 05/88] Add expr_functor.h --- include/tvm/relay/expr_functor.h | 143 +++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 include/tvm/relay/expr_functor.h diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h new file mode 100644 index 000000000000..922892e8a7a5 --- /dev/null +++ b/include/tvm/relay/expr_functor.h @@ -0,0 +1,143 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file expr_functor.h + * \brief A more powerful Visitor that enables defining arbitrary function + * signatures with dispatch on first argument. + */ +#ifndef TVM_RELAY_EXPR_FUNCTOR_H_ +#define TVM_RELAY_EXPR_FUNCTOR_H_ + +#include +#include +#include "ir.h" + +namespace tvm { +namespace relay { + +/*! + * \brief A dynamical functor that dispatches on in the first Expr argument. + * You can use this as a more powerful Visitor, since it allows you to + * define function signatures of Visit Function. + * + * This helps you to avoid to book-keep return value of Visitor via state, + * which can cause bugs easily when state is incorrectly maintained. + * + * \code + * // A functor that set variable to b. and calculate results. + * class MyExprFunctor + * : public ir::ExprFunctor { + * public: + * int VisitExpr_(const Variable* op, int b) final { + * return b; + * } + * int VisitExpr_(const IntImm* op, int b) final { + * return op->value; + * } + * int VisitExpr_(const Add* op, int b) final { + * return Visit(op->a, b) + Visit(op->b, b); + * } + * }; + * MyExprFunctor f; + * Var x("x"); + * CHECK_EQ(f(x + 1, 2), 3); + * \endcode + * + * \note Why do we need this more powerful Functor: + * + * We often need to implement a transformer tasks. + * Say we want to take Expr and transform it to some analysis result, + * This easily be done incorrectly using plain Visitor. See IRVisitor's + * document for possible error cases. + * + * \tparam FType function signiture + * This type if only defined for FType with function signiture R(const Expr&, + * Args...) + */ +template +class ExprFunctor; + +// functions to be overriden. +#define EXPR_FUNCTOR_DEFAULT \ + { return VisitExprDefault_(op, std::forward(args)...); } + +#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch( \ + [](const NodeRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.node_.get()), \ + std::forward(args)...); \ + }); + +template +class ExprFunctor { + private: + using TSelf = ExprFunctor; + using FType = tvm::IRFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~ExprFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Expr& n, Args... args) { + return VisitExpr(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitExpr(const Expr& n, Args... args) { + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitExpr_(const ConstantNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const TupleNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LocalVarNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GlobalVarNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ParamNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FunctionNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const IfNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const OperatorNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExprDefault_(const Node* op, Args...) { + throw dmlc::Error(std::string("Do not have a default for ") + op->type_key()); + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode); + RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode); + RELAY_EXPR_FUNCTOR_DISPATCH(LocalVarNode); + RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); + RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode); + RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode); + RELAY_EXPR_FUNCTOR_DISPATCH(CallNode); + RELAY_EXPR_FUNCTOR_DISPATCH(LetNode); + RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); + RELAY_EXPR_FUNCTOR_DISPATCH(OperatorNode); + return vtable; + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_EXPR_FUNCTOR_H_ From bbb7b7ea1dfebe3b5c05ed89e509b1dab906e1a4 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 22:47:14 -0700 Subject: [PATCH 06/88] Add initial version of type_functor.h --- include/tvm/relay/compiler/type_functor.h | 93 +++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 include/tvm/relay/compiler/type_functor.h diff --git a/include/tvm/relay/compiler/type_functor.h b/include/tvm/relay/compiler/type_functor.h new file mode 100644 index 000000000000..66454725db48 --- /dev/null +++ b/include/tvm/relay/compiler/type_functor.h @@ -0,0 +1,93 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_functor.h + * \brief A way to defined arbitrary function signature with dispatch on types. + */ +#ifndef TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ +#define TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ + +#include +#include "ir.h" + +namespace tvm { +namespace relay { + +template +class TypeFunctor; + +// functions to be overriden. +#define TYPE_FUNCTOR_DEFAULT \ + { return VisitTypeDefault_(op, std::forward(args)...); } + +#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch( \ + [](const NodeRef& n, TSelf* self, Args... args) { \ + return self->VisitType_(static_cast(n.node_.get()), \ + std::forward(args)...); \ + }); + +template +class TypeFunctor { + private: + using TSelf = TypeFunctor; + using FType = tvm::IRFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~TypeFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Type& n, Args... args) { + return VisitType(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitType(const Type& n, Args... args) { + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitType_(const TensorTypeNode* op, + Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeParamNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeFunction* op, Args... args) TYPE_FUNCTOR_DEFAULT; + Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + + virtual R VisitTypeDefault_(const Node* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->type_key(); + return R(); + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeParamNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode); + RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeFunctionNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode); + return vtable; + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ From 697ba97fe50bf84c71bf38a033f1d5838349c1a9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 22:53:00 -0700 Subject: [PATCH 07/88] Make type_functor.h a private header --- {include/tvm => src}/relay/compiler/type_functor.h | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {include/tvm => src}/relay/compiler/type_functor.h (100%) diff --git a/include/tvm/relay/compiler/type_functor.h b/src/relay/compiler/type_functor.h similarity index 100% rename from include/tvm/relay/compiler/type_functor.h rename to src/relay/compiler/type_functor.h From a4c7df1e5d588f34510f892170257382b60d67ab Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 23:05:16 -0700 Subject: [PATCH 08/88] Add ir.h --- include/tvm/relay/ir.h | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 include/tvm/relay/ir.h diff --git a/include/tvm/relay/ir.h b/include/tvm/relay/ir.h new file mode 100644 index 000000000000..73c275cf1c98 --- /dev/null +++ b/include/tvm/relay/ir.h @@ -0,0 +1,20 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/ir.h + * \brief The Relay intermediate representation's core data structures. + */ +#ifndef TVM_RELAY_IR_H_ +#define TVM_RELAY_IR_H_ + +#include "./base.h" +#include "./type.h" +#include "./expr.h" +#include "./op.h" + +// namespace tvm { +// namespace relay { + +// } // namespace relay +// } // namespace tvm + +#endif // TVM_RELAY_IR_H_ From f1b9e925300c52087d1b6fcde7e82568c93e0eb8 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 23:07:02 -0700 Subject: [PATCH 09/88] Add back Relay's logging.h --- include/tvm/relay/logging.h | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 include/tvm/relay/logging.h diff --git a/include/tvm/relay/logging.h b/include/tvm/relay/logging.h new file mode 100644 index 000000000000..99cfc44de6cb --- /dev/null +++ b/include/tvm/relay/logging.h @@ -0,0 +1,33 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/logging.h + * \brief A wrapper around dmlc-core/logging.h which adds the ability + * to toggle logging via an environment variable. + */ + +#ifndef TVM_RELAY_LOGGING_H_ +#define TVM_RELAY_LOGGING_H_ + +#include +#include +#include +#include "dmlc/logging.h" + +namespace tvm { +namespace relay { + +static bool logging_enabled() { + if (auto var = std::getenv("RELAY_LOG")) { + std::string is_on(var); + return is_on == "1"; + } else { + return false; + } +} + +#define RELAY_LOG(severity) LOG_IF(severity, logging_enabled()) + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_LOGGING_H_ From de80240d477750a6490a464587e1e1b9f51876fe Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 23:14:49 -0700 Subject: [PATCH 10/88] Add type checker header --- include/tvm/relay/compiler/typechecker.h | 25 ++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 include/tvm/relay/compiler/typechecker.h diff --git a/include/tvm/relay/compiler/typechecker.h b/include/tvm/relay/compiler/typechecker.h new file mode 100644 index 000000000000..c71f78c1a5b0 --- /dev/null +++ b/include/tvm/relay/compiler/typechecker.h @@ -0,0 +1,25 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file tvm/relay/typechecker.h + * \brief Type check a Relay program producing a type checked program + * with its checked_type field populated and incomplete types resolved. + */ +#ifndef TVM_RELAY_COMPILER_TYPECHECKER_H_ +#define TVM_RELAY_COMPILER_TYPECHECKER_H_ + +#include "tvm/relay/ir.h" +#include "tvm/relay/environment.h" + +namespace tvm { +namespace relay { + +/*! The result of type checking an expression is a new expression + * with unambigous type information filled in, as well as it's + * checked type field populated with the result type. + */ +Expr check(const Environment & env, const Expr & e); +Operator check(const Environment & env, const Operator & op); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_COMPILER_TYPECHECKER_H_ From ebc1bf2e0d27c1ab51816fb56757ac2fa58b0653 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 00:04:16 -0700 Subject: [PATCH 11/88] Add alpha_eq --- include/tvm/relay/compiler/alpha_eq.h | 19 + src/relay/compiler/alpha_eq.cc | 284 +++++++++++++ tests/python/relay/test_alpha_eq.py | 576 ++++++++++++++++++++++++++ 3 files changed, 879 insertions(+) create mode 100644 include/tvm/relay/compiler/alpha_eq.h create mode 100644 src/relay/compiler/alpha_eq.cc create mode 100644 tests/python/relay/test_alpha_eq.py diff --git a/include/tvm/relay/compiler/alpha_eq.h b/include/tvm/relay/compiler/alpha_eq.h new file mode 100644 index 000000000000..ba91afc21015 --- /dev/null +++ b/include/tvm/relay/compiler/alpha_eq.h @@ -0,0 +1,19 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/alpha_eq.h + * \brief Check expressions & types for structural equivalence. + */ +#ifndef TVM_RELAY_ALPHA_EQ_H_ +#define TVM_RELAY_ALPHA_EQ_H_ + +#include "tvm/relay/ir.h" + +namespace tvm { +namespace relay { + +bool alpha_eq(const Expr & e1, const Expr & e2); +bool alpha_eq(const Type & t1, const Type & t2); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ALPHA_EQ_H_ diff --git a/src/relay/compiler/alpha_eq.cc b/src/relay/compiler/alpha_eq.cc new file mode 100644 index 000000000000..4b8e904bf29e --- /dev/null +++ b/src/relay/compiler/alpha_eq.cc @@ -0,0 +1,284 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file alpha_eq.cc + * \brief Compute the set of variables not bound in the expression. + */ +#include "tvm/relay/compiler/alpha_eq.h" +#include "tvm/relay/expr_visitor.h" +#include "./type_visitor.h" + +namespace tvm { +namespace relay { + +using namespace tvm::runtime; + +struct TypeAlphaEq : TypeVisitor { + tvm::Map eq_map; + bool equal; + + TypeAlphaEq() : eq_map(), equal(true) {} + + void DataTypeEqual(const DataType & dt1, const DataType & dt2) { + equal = equal && dt1 == dt2; + } + void ShapeEqual(Array s1, Array s2) { + } + + void VisitType_(const TensorTypeNode *tt1, const Type &t2) override { + if (const TensorTypeNode *tt2 = t2.as()) { + DataTypeEqual(tt1->dtype, tt2->dtype); + ShapeEqual(tt1->shape, tt2->shape); + } else { + equal = false; + } + } + +// void VisitType_(const TypeVarNode *bt1, const Type &t2) override { +// if (const TypeVarNode *bt2 = t2.as()) { +// equal = equal && bt1 == bt2; +// return; +// } else { +// equal = false; +// } +// } + + void VisitType_(const TypeParamNode *ti1, const Type &t2) override { + if (const TypeParamNode *ti2 = t2.as()) { + auto tid1 = GetRef(ti1); + auto tid2 = GetRef(ti2); + + // We handle open terms with this rule assuming variables are identical. + // + // Not sure if we should do this. + if (tid1 == tid2) { + return; + } + + // Check that they are same kind + if (tid1->kind != tid2->kind) { + equal = false; + return; + } + + // Next we see if there is mapping for local1 into the rhs term. + // If there is we check to see if those are equal. + if (eq_map.find(tid1) != eq_map.end()) { + equal = equal && eq_map[tid1] == tid2; + } else { + equal = false; + } + } else { + equal = false; + } + } + + void VisitType_(const FuncTypeNode *op, const Type &t2) override { + if (const FuncTypeNode *ta2 = t2.as()) { + if (op->arg_types.size() != ta2->arg_types.size()) { + equal = false; + return; + } + + for (size_t i = 0; i < op->arg_types.size(); i++) { + this->VisitType(op->arg_types[i], ta2->arg_types[i]); + if (!equal) { + return; + } + } + + this->VisitType(op->ret_type, ta2->ret_type); + } else { + equal = false; + } + } + + void VisitType_(const TypeFunctionNode *op, const Type &t2) override { + } +// void VisitType_(const TupleTypeNode *op, const Type &t2) override { +// if (const TupleTypeNode *pt = t2.as()) { +// if (op->fields.size() != pt->fields.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < op->fields.size(); i++) { +// if (!equal) { +// return; +// } +// this->VisitType(op->fields[i], pt->fields[i]); +// } +// } else { +// equal = false; +// } +// } + +// void VisitType_(const TypeCallNode *tyn1, const Type &t2) override { +// TypeCall tycall = GetRef(tyn1); +// if (const TypeCallNode *tyn2 = t2.as()) { +// if (tycall->func != tyn2->func) { +// equal = false; +// return; +// } + +// if (tycall->args.size() != tyn2->args.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < tycall->args.size(); i++) { +// this->VisitType(tycall->args[i], tyn2->args[i]); +// } +// } else { +// equal = false; +// } +// } +}; + +bool alpha_eq(const Type &t1, const Type &t2) { + TypeAlphaEq aeq; + aeq.VisitType(t1, t2); + return aeq.equal; +} + +// struct AlphaEq : ExprVisitor { +// public: +// tvm::Map eq_map; +// bool equal; +// AlphaEq() : eq_map(), equal(true) {} + +// void VisitExpr_(const LocalIdNode *e1, const Expr &e2) override { +// if (const LocalIdNode *id2 = e2.as()) { +// auto local1 = GetRef(e1); +// auto local2 = GetRef(id2); +// // +// // We handle open terms with this rule assuming variables are identical. +// // +// // Not sure if we should do this. +// if (local1 == local2) { +// equal = true; +// return; +// } + +// // Next we see if there is mapping for local1 into the rhs term. +// // If there is we check to see if those are equal. +// if (eq_map.find(local1) != eq_map.end()) { +// equal = equal && eq_map[local1] == local2; +// } else { +// equal = false; +// } +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const GlobalIdNode *g1, const Expr &e2) override { +// if (const GlobalIdNode *g2 = e2.as()) { +// equal = equal && g1 == g2; +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const OperatorIdNode *i1, const Expr &e2) override { +// if (const OperatorIdNode *i2 = e2.as()) { +// equal = equal && i1 == i2; +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const TupleNode *pl1, const Expr &e2) override { +// Tuple prod1 = GetRef(pl1); +// if (const TupleNode *pl2 = e2.as()) { +// Tuple prod2 = GetRef(pl2); +// if (prod1->fields.size() != prod2->fields.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < prod1->fields.size(); i++) { +// this->VisitExpr(prod1->fields[i], prod2->fields[i]); +// } +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const ParamNode *p1, const Expr &e2) override { +// if (const ParamNode *p2 = e2.as()) { +// eq_map.Set(p1->id, p2->id); +// equal = equal && alpha_eq(p1->type, p2->type); +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const FunctionNode *func1, const Expr &e2) override { +// if (const FunctionNode *func2 = e2.as()) { +// if (func1->params.size() != func2->params.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < func1->params.size(); i++) { +// this->VisitExpr(func1->params[i], func2->params[i]); +// } + +// this->VisitExpr(func1->body, func2->body); +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const CallNode *op, const Expr &e2) override { +// if (const CallNode *call = e2.as()) { +// this->VisitExpr(op->fn, call->fn); + +// if (op->args.size() != call->args.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < op->args.size(); i++) { +// this->VisitExpr(op->args[i], call->args[i]); +// } + +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const LetNode *op, const Expr &e2) override { +// if (const LetNode *let = e2.as()) { +// eq_map.Set(op->id, let->id); +// this->VisitExpr(op->value, let->value); +// this->VisitExpr(op->body, let->body); +// } else { +// equal = false; +// } +// } +// }; + +// bool alpha_eq(const Expr &e1, const Expr &e2) { +// AlphaEq eq; +// eq.VisitExpr(e1, e2); +// return eq.equal; +// } + +// // TODO(@jroesch): move to correct namespace? +// TVM_REGISTER_API("relay._make._alpha_eq") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Expr e1 = args[0]; +// Expr e2 = args[1]; +// *ret = alpha_eq(e1, e2); +// }); + +TVM_REGISTER_API("relay._make._type_alpha_eq") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Type t1 = args[0]; + Type t2 = args[1]; + *ret = alpha_eq(t1, t2); + }); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_alpha_eq.py b/tests/python/relay/test_alpha_eq.py new file mode 100644 index 000000000000..f1dc81c3c483 --- /dev/null +++ b/tests/python/relay/test_alpha_eq.py @@ -0,0 +1,576 @@ +"""Test alpha-equivalence of expressions and types.""" +# pylint: disable=invalid-name, missing-docstring +# pylint: disable=wildcard-import, unused-wildcard-import +from relay.make import * +from relay.ir import alpha_eq, ShapeOp, Kind +from relay.typing import TYPE_DEFAULTS +from relay import ir + +INT_TYPE_WIDTH = TYPE_DEFAULTS["INT_WIDTH"] +INT_TYPE_LANES = TYPE_DEFAULTS["INT_LANES"] + +def int_type(width=32) -> ir.Type: + return TensorType(IntType(width), ShapeSeq([])) + +def float_type(width=32) -> ir.Type: + return TensorType(FloatType(width), ShapeSeq([])) + +def bool_type() -> ir.Type: + return TensorType(BoolType(), ShapeSeq([])) + +def nest_quantifiers(ids, body) -> ir.Type: + ret = body + for tid in reversed(ids): + ret = TypeQuantifier(tid, ret) + return ret + +def test_local_id_not_eq() -> None: + assert not alpha_eq(LocalId("x"), LocalId("y")) + +def test_local_id_eq() -> None: + x = LocalId("x") + assert alpha_eq(x, x) + +def test_global_id_not_eq() -> None: + left = GlobalId("xyz") + right = GlobalId("xyz") + assert not alpha_eq(left, right) + +def test_global_id_eq() -> None: + ident = GlobalId("xyz") + assert alpha_eq(ident, ident) + +def test_operator_id_not_eq() -> None: + left = OperatorId("xyz") + right = OperatorId("xyz") + # equality on operator id is pointer equality + assert not alpha_eq(left, right) + +def test_operator_id_eq() -> None: + x = OperatorId("xyz") + assert alpha_eq(x, x) + +def test_float_literal_eq() -> None: + x = FloatLit(1.0) + y = FloatLit(1.0) + assert alpha_eq(x, y) + +def test_float_literal_not_eq() -> None: + x = FloatLit(1.0) + y = FloatLit(2.0) + assert not alpha_eq(x, y) + +def test_int_literal_eq() -> None: + x = IntLit(1) + y = IntLit(1) + assert alpha_eq(x, y) + +def test_int_literal_not_eq() -> None: + x = IntLit(1) + y = IntLit(2) + assert not alpha_eq(x, y) + +def test_bool_literal_eq() -> None: + x = BoolLit(True) + y = BoolLit(True) + assert alpha_eq(x, y) + +def test_bool_literal_not_eq() -> None: + x = BoolLit(True) + y = BoolLit(False) + assert not alpha_eq(x, y) + +def test_tensor_literal_eq() -> None: + x = TensorLit([IntLit(1), IntLit(2)]) + y = TensorLit([IntLit(1), IntLit(2)]) + assert alpha_eq(x, y) + +def test_tensor_literal_not_eq() -> None: + x = TensorLit([IntLit(1), IntLit(2)]) + y = TensorLit([IntLit(1), IntLit(3)]) + z = TensorLit([IntLit(1)]) + assert not alpha_eq(x, y) + assert not alpha_eq(x, z) + +def test_product_literal_eq() -> None: + x = Tuple([IntLit(1), IntLit(2)]) + y = Tuple([IntLit(1), IntLit(2)]) + assert alpha_eq(x, y) + +def test_product_literal_not_eq() -> None: + x = Tuple([IntLit(1), IntLit(2)]) + y = Tuple([IntLit(2), IntLit(2)]) + z = Tuple([IntLit(1), IntLit(2), IntLit(3)]) + assert not alpha_eq(x, y) + assert not alpha_eq(x, z) + +def test_projection_eq() -> None: + prod = Tuple([IntLit(3), FloatLit(3.5)]) + + assert alpha_eq(Projection(prod, 0), Projection(prod, 0)) + assert alpha_eq(Projection(prod, 1), Projection(prod, 1)) + +def test_projection_not_eq() -> None: + prod1 = Tuple([IntLit(3), IntLit(4)]) + prod2 = Tuple([IntLit(3)]) + prod3 = Tuple([IntLit(3), IntLit(4), FloatLit(3.5)]) + + assert not alpha_eq(Projection(prod1, 0), Projection(prod1, 1)) + assert not alpha_eq(Projection(prod1, 0), Projection(prod2, 0)) + assert not alpha_eq(Projection(prod1, 0), Projection(prod3, 0)) + assert not alpha_eq(Projection(prod1, 1), Projection(prod3, 1)) + +def test_cast_not_eq() -> None: + left = Cast(IntType(1), IntLit(2)) + right = Cast(IntType(1), IntLit(1)) + assert not alpha_eq(left, right) + + # same literal, different type + left = Cast(IntType(1), IntLit(2)) + right = Cast(IntType(2), IntLit(2)) + assert not alpha_eq(left, right) + +def test_cast_eq() -> None: + left = Cast(IntType(1), IntLit(2)) + right = Cast(IntType(1), IntLit(2)) + assert alpha_eq(left, right) + +def test_param_not_eq() -> None: + left = Param(LocalId("foo"), int_type()) + right = Param(LocalId("foo"), bool_type()) + assert not alpha_eq(left, right) + +def test_param_eq() -> None: + left = Param(LocalId("foo"), int_type()) + right = Param(LocalId("bar"), int_type()) + assert alpha_eq(left, right) + +def test_function_not_eq() -> None: + params1 = [Param(LocalId("x"), int_type())] + fn1 = Function([], params1, int_type(), LocalId("x")) + params2 = [Param(LocalId("y"), bool_type())] + fn2 = Function([], params2, int_type(), LocalId("y")) + assert not alpha_eq(fn1, fn2) + + params3 = [Param(LocalId("x"), int_type()), Param(LocalId("y"), int_type())] + fn3 = Function([], params3, int_type(), LocalId("z")) + assert not alpha_eq(fn1, fn3) + +def test_function_eq() -> None: + x = LocalId("x") + y = LocalId("y") + params1 = [Param(x, int_type())] + fn1 = Function([], params1, int_type(), x) + params2 = [Param(y, int_type())] + fn2 = Function([], params2, int_type(), y) + assert alpha_eq(fn1, fn2) + +def test_call_not_eq() -> None: + x = LocalId("x") + y = LocalId("y") + params1 = [Param(x, int_type())] + fn1 = Function([], params1, int_type(), x) + args1 = [IntLit(1)] + call1 = Call(fn1, args1) + + args2 = [IntLit(2)] + call2 = Call(fn1, args2) + assert not alpha_eq(call1, call2) + + params2 = [Param(y, int_type())] + fn2 = Function([], params2, float_type(), FloatLit(0.0)) + call3 = Call(fn2, args1) + assert not alpha_eq(call1, call3) + assert not alpha_eq(call2, call3) + +def test_call_eq() -> None: + x = LocalId("x") + y = LocalId("y") + params1 = [Param(x, int_type())] + fn1 = Function([], params1, int_type(), x) + args = [IntLit(1)] + call1 = Call(fn1, args) + + params2 = [Param(y, int_type())] + fn2 = Function([], params2, int_type(), y) + call2 = Call(fn2, args) + assert alpha_eq(call1, call2) + +def test_debug_not_eq() -> None: + left = Debug(IntLit(1)) + right = Debug(IntLit(2)) + assert not alpha_eq(left, right) + +def test_debug_eq() -> None: + left = Debug(IntLit(1)) + right = Debug(IntLit(1)) + assert alpha_eq(left, right) + +def test_let_not_eq() -> None: + x = LocalId("x") + y = LocalId("y") + let1 = Let(x, int_type(), IntLit(10), IntLit(11)) + let2 = Let(y, int_type(), IntLit(10), IntLit(12)) + assert not alpha_eq(let1, let2) + + let3 = Let(x, int_type(), IntLit(10), x) + let4 = Let(y, int_type(), IntLit(12), y) + assert not alpha_eq(let3, let4) + +def test_let_eq() -> None: + x = LocalId("x") + y = LocalId("y") + let1 = Let(x, int_type(), IntLit(10), x) + let2 = Let(y, int_type(), IntLit(10), y) + assert alpha_eq(let1, let2) + +def test_ref_eq() -> None: + r1 = Ref(IntLit(5)) + r2 = Ref(IntLit(5)) + assert alpha_eq(r1, r2) + +def test_ref_not_eq() -> None: + r1 = Ref(IntLit(5)) + r2 = Ref(FloatLit(3.5)) + r3 = Ref(r1) + assert not alpha_eq(r1, r2) + assert not alpha_eq(r1, r3) + assert not alpha_eq(r2, r3) + +def test_val_ref_eq() -> None: + vr1 = ReadRef(Ref(IntLit(35))) + vr2 = ReadRef(Ref(Tuple([IntLit(12), FloatLit(2.5)]))) + assert alpha_eq(vr1, vr1) + assert alpha_eq(vr2, vr2) + +def test_val_ref_not_eq() -> None: + vr1 = ReadRef(Ref(IntLit(5))) + vr2 = ReadRef(Ref(vr1)) + vr3 = ReadRef(Ref(FloatLit(5.0))) + assert not alpha_eq(vr1, vr2) + assert not alpha_eq(vr1, vr3) + assert not alpha_eq(vr2, vr3) + +def test_set_ref_eq() -> None: + sr1 = WriteRef(Ref(FloatLit(5.0)), FloatLit(6.0)) + sr2 = WriteRef(Ref(Tuple([IntLit(3), BoolLit(False)])), + Tuple([IntLit(5), BoolLit(True)])) + assert alpha_eq(sr1, sr1) + assert alpha_eq(sr2, sr2) + +def test_set_ref_not_eq() -> None: + r1 = Ref(FloatLit(5.0)) + r2 = Ref(IntLit(5)) + r3 = Ref(IntLit(6)) + + assert not alpha_eq(WriteRef(r1, FloatLit(6.0)), + WriteRef(r2, IntLit(6))) + assert not alpha_eq(WriteRef(r2, IntLit(6)), WriteRef(r2, IntLit(7))) + assert not alpha_eq(WriteRef(r2, IntLit(7)), WriteRef(r3, IntLit(7))) + +# Type alpha-equality tests + +def test_base_type_eq() -> None: + assert alpha_eq(IntType(32), IntType(32)) + assert alpha_eq(BoolType(), BoolType()) + assert alpha_eq(FloatType(32), FloatType(32)) + +def test_tensor_type_eq() -> None: + tt1 = TensorType( + IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) + tt2 = TensorType( + FloatType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) + assert alpha_eq(tt1, tt1) + assert alpha_eq(tt2, tt2) + +def test_tensor_type_not_eq() -> None: + tt1 = TensorType( + IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) + tt2 = TensorType( + FloatType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) + tt3 = TensorType( + IntType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) + assert not alpha_eq(tt1, tt2) + assert not alpha_eq(tt1, tt3) + +def test_ref_type_eq() -> None: + rt1 = RefType(int_type()) + rt2 = RefType(float_type()) + assert alpha_eq(rt1, rt1) + assert alpha_eq(rt2, rt2) + +def test_ref_type_not_eq() -> None: + rt1 = RefType(int_type()) + rt2 = RefType(float_type()) + assert not alpha_eq(rt1, rt2) + +def test_product_type_eq() -> None: + pt1 = TupleType([int_type(), RefType(float_type())]) + pt2 = TupleType([float_type(), float_type(), int_type()]) + assert alpha_eq(pt1, pt1) + assert alpha_eq(pt2, pt2) + +def test_product_type_not_eq() -> None: + pt1 = TupleType([int_type(), int_type()]) + pt2 = TupleType([int_type(), int_type(), float_type()]) + pt3 = TupleType([bool_type(), float_type()]) + assert not alpha_eq(pt1, pt2) + assert not alpha_eq(pt1, pt3) + +def test_type_id_eq() -> None: + id1 = TypeParam("id1", Kind.Shape) + id2 = TypeParam("id2", Kind.BaseType) + id3 = TypeParam("id2", Kind.Type) + + assert alpha_eq(id1, id1) + assert alpha_eq(id2, id2) + assert alpha_eq(id3, id3) + +def test_type_id_not_eq() -> None: + # name is just a hint, we use pointer equality as the rule + # (unless there is a quantifier to give context) + id1 = TypeParam("id1", Kind.Shape) + id2 = TypeParam("id1", Kind.Shape) + id3 = TypeParam("id3", Kind.BaseType) + + assert not alpha_eq(id1, id2) + assert not alpha_eq(id1, id3) + +def test_arrow_type_eq() -> None: + ar1 = TypeArrow([int_type()], bool_type()) + ar2 = TypeArrow([int_type(), int_type()], TupleType([])) + assert alpha_eq(ar1, ar1) + assert alpha_eq(ar2, ar2) + +def test_arrow_type_not_eq() -> None: + t1 = int_type() + t2 = bool_type() + t3 = [int_type(), bool_type()] + + assert not alpha_eq(TypeArrow([t1], t2), TypeArrow([t1], t1)) + assert not alpha_eq(TypeArrow(t3, t1), TypeArrow([t2], t1)) + assert not alpha_eq(TypeArrow([t1], TypeArrow([t1], t1)), + TypeArrow([t1], t1)) + +def test_type_quantifier_eq() -> None: + id1 = TypeParam("id1", Kind.Shape) + id2 = TypeParam("id2", Kind.Shape) + tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) + tq2 = TypeQuantifier(id2, TensorType(IntType(32), id2)) + + assert alpha_eq(tq1, tq1) + assert alpha_eq(tq1, tq2) + +def test_nested_type_quantifier_eq() -> None: + id1 = TypeParam("id1", Kind.BaseType) + id2 = TypeParam("id2", Kind.Shape) + id3 = TypeParam("id3", Kind.BaseType) + id4 = TypeParam("id4", Kind.Shape) + tq1 = TypeQuantifier(id1, TypeQuantifier(id2, TensorType(id1, id2))) + tq2 = TypeQuantifier(id3, TypeQuantifier(id4, TensorType(id3, id4))) + + assert alpha_eq(tq1, tq1) + assert alpha_eq(tq1, tq2) + +def test_type_quantifier_not_eq() -> None: + id1 = TypeParam("id1", Kind.Shape) + id2 = TypeParam("id2", Kind.BaseType) + id3 = TypeParam("id3", Kind.Shape) + + tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) + tq2 = TypeQuantifier(id2, TensorType(id2, ShapeSeq([ShapeSingleton(3)]))) + tq3 = TypeQuantifier(id1, TensorType(IntType(32), id3)) + tq4 = TypeQuantifier(id1, TensorType(FloatType(32), id1)) + + assert not alpha_eq(tq1, tq2) + assert not alpha_eq(tq1, tq3) + assert not alpha_eq(tq1, tq4) + assert not alpha_eq(tq2, tq3) + assert not alpha_eq(tq2, tq4) + +def test_shape_singleton_eq() -> None: + single1 = ShapeSingleton(10) + single2 = ShapeSingleton(10) + + assert alpha_eq(single1, single1) + assert alpha_eq(single1, single2) + +def test_shape_singelton_not_eq() -> None: + single1 = ShapeSingleton(10) + single2 = ShapeSingleton(11) + + assert not alpha_eq(single1, single2) + +def test_shape_attr_eq() -> None: + attr1 = ShapeAttr("x") + attr2 = ShapeAttr("x") + + assert alpha_eq(attr1, attr1) + assert alpha_eq(attr1, attr2) + +def test_shape_attr_not_eq() -> None: + id1 = "x" + id2 = "y" + attr1 = ShapeAttr(id1) + attr2 = ShapeAttr(id2) + + assert not alpha_eq(attr1, attr2) + +def test_shape_seq_eq() -> None: + empty = ShapeSeq([]) + seq1 = ShapeSeq([ShapeSingleton(5)]) + seq2 = ShapeSeq([ShapeSingleton(5)]) + + assert alpha_eq(empty, empty) + assert alpha_eq(seq1, seq2) + +def test_shape_seq_not_eq() -> None: + empty = ShapeSeq([]) + seq = ShapeSeq([ShapeSingleton(5)]) + single = ShapeSingleton(5) + + assert not alpha_eq(empty, seq) + assert not alpha_eq(seq, single) + +def test_shape_projection_eq() -> None: + proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) + proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) + + assert alpha_eq(proj1, proj2) + +def test_shape_projection_not_eq() -> None: + proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) + proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 1) + proj3 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 0) + proj4 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 1) + + assert not alpha_eq(proj1, proj2) + assert not alpha_eq(proj1, proj3) + assert not alpha_eq(proj1, proj4) + assert not alpha_eq(proj2, proj3) + assert not alpha_eq(proj2, proj4) + assert not alpha_eq(proj3, proj4) + +def test_shape_binary_op_eq() -> None: + empty = ShapeSeq([]) + single = ShapeSingleton(5) + seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) + + op1 = ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty) + op2 = ShapeBinaryOp(ShapeOp.SHSUB, single, single) + op3 = ShapeBinaryOp(ShapeOp.SHMUL, seq, seq) + op4 = ShapeBinaryOp(ShapeOp.SHDIV, seq, seq) + + assert alpha_eq(op1, op1) + assert alpha_eq(op2, op2) + assert alpha_eq(op3, op3) + assert alpha_eq(op4, op4) + +def test_shape_binary_op_not_eq() -> None: + empty = ShapeSeq([]) + single = ShapeSingleton(5) + seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) + + assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), empty) + assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHMUL, seq, ShapeSingleton(1)), seq) + assert not alpha_eq( + ShapeBinaryOp(ShapeOp.SHPLUS, single, single), + ShapeBinaryOp(ShapeOp.SHPLUS, + ShapeSeq([single]), + ShapeSeq([single]))) + assert not alpha_eq( + ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), + ShapeBinaryOp(ShapeOp.SHSUB, empty, empty)) + assert not alpha_eq( + ShapeBinaryOp(ShapeOp.SHMUL, empty, empty), + ShapeBinaryOp(ShapeOp.SHDIV, empty, empty)) + +def test_shape_nested_in_quantifier() -> None: + b1 = TypeParam("b", Kind.BaseType) + x1 = TypeParam("x", Kind.Shape) + y1 = TypeParam("y", Kind.Shape) + + b2 = TypeParam("b", Kind.BaseType) + x2 = TypeParam("x", Kind.Shape) + y2 = TypeParam("y", Kind.Shape) + + b3 = TypeParam("b", Kind.BaseType) + x3 = TypeParam("x", Kind.Shape) + y3 = TypeParam("y", Kind.Shape) + + tq1 = nest_quantifiers( + [b1, x1, y1], + TypeArrow( + [TensorType(b1, x1), TensorType(b1, y2)], + TensorType( + b1, + ShapeBinaryOp(ShapeOp.SHPLUS, + ShapeSeq([x1, ShapeProjection(y1, 1), + ShapeSingleton(5), ShapeAttr("att")]), + ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + + tq2 = nest_quantifiers( + [b2, x2, y2], + TypeArrow( + [TensorType(b2, x2), TensorType(b2, y2)], + TensorType( + b2, + ShapeBinaryOp(ShapeOp.SHPLUS, + ShapeSeq([x2, ShapeProjection(y2, 1), + ShapeSingleton(5), ShapeAttr("att")]), + ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + + # different attr, var order, position, and constant + tq3 = nest_quantifiers( + [b3, x3, y3], + TypeArrow( + [TensorType(b3, x3), TensorType(b3, y3)], + TensorType( + b3, + ShapeBinaryOp(ShapeOp.SHPLUS, + ShapeSeq([x3, ShapeProjection(y3, 1), + ShapeSingleton(4), ShapeAttr("att")]), + ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + + tq4 = nest_quantifiers( + [b3, x3, y3], + TypeArrow( + [TensorType(b3, x3), TensorType(b3, y3)], + TensorType( + b3, + ShapeBinaryOp(ShapeOp.SHPLUS, + ShapeSeq([x3, ShapeProjection(y3, 2), + ShapeSingleton(5), ShapeAttr("att2")]), + ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + + tq5 = nest_quantifiers( + [b3, x3, y3], + TypeArrow( + [TensorType(b3, x3), TensorType(b3, y3)], + TensorType( + b3, + ShapeBinaryOp(ShapeOp.SHMUL, + ShapeSeq([x3, ShapeProjection(y3, 1), + ShapeSingleton(5), ShapeAttr("att")]), + ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + + tq6 = nest_quantifiers( + [b3, y3, x3], + TypeArrow( + [TensorType(b3, x3), TensorType(b3, y3)], + TensorType( + b3, + ShapeBinaryOp(ShapeOp.SHPLUS, + ShapeSeq([x3, ShapeProjection(y3, 1), + ShapeSingleton(5), ShapeAttr("att")]), + ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + + assert alpha_eq(tq1, tq2) + assert not alpha_eq(tq1, tq3) + assert not alpha_eq(tq2, tq3) + assert not alpha_eq(tq1, tq4) + assert not alpha_eq(tq2, tq4) + assert not alpha_eq(tq1, tq5) + assert not alpha_eq(tq2, tq5) + assert not alpha_eq(tq1, tq6) + assert not alpha_eq(tq2, tq6) From 2d83e48aa6e1519605b9f1e3113e9f71dca16aac Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 00:06:44 -0700 Subject: [PATCH 12/88] Add incomplete_type.h --- src/relay/compiler/incomplete_type.h | 36 ++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 src/relay/compiler/incomplete_type.h diff --git a/src/relay/compiler/incomplete_type.h b/src/relay/compiler/incomplete_type.h new file mode 100644 index 000000000000..8f360d1cd51c --- /dev/null +++ b/src/relay/compiler/incomplete_type.h @@ -0,0 +1,36 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file incomplete_type.h + * \brief A way to defined arbitrary function signature with dispatch on types. + */ + +#ifndef TVM_RELAY_COMPILER_INCOMPLETE_TYPE_H +#define TVM_RELAY_COMPILER_INCOMPLETE_TYPE_H + +#include "tvm/relay/ir.h" + +namespace tvm { +namespace relay { + +/*! + * \brief Represents a portion of an incomplete type. + */ +class IncompleteType; + +/*! \brief IncompleteType container node */ +class IncompleteTypeNode : public TypeNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) final {} + + TVM_DLL static IncompleteType make(); + + static constexpr const char* _type_key = "relay.IncompleteType"; + TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COMPILER_INCOMPLETE_TYPE_H From ad068cd6c29fd98697e4ddb7ee1c18f9395a4641 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 00:07:05 -0700 Subject: [PATCH 13/88] Add type call --- include/tvm/relay/type.h | 33 ++++++++++++++++++++++++++++++++- src/relay/type.cc | 19 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 4c6995646114..dfe4309b7c77 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -228,7 +228,38 @@ class TypeFunctionNode : public RelayNode { TVM_DECLARE_NODE_TYPE_INFO(TypeFunctionNode, RelayNode); }; -RELAY_DEFINE_NODE_REF(TypeFunction, TypeFunctionNode, NodeRef); +RELAY_DEFINE_NODE_REF(TypeFunction, TypeFunctionNode, Type); + +/*! + * \brief Call a type function with some number of arguments. + */ +class TypeCall; +/*! + * \brief TypeCall container. + */ +class TypeCallNode : public TypeNode { + public: + /*! \brief The type function to be called. */ + Type func; + /*! \brief The type arguments to the type function. */ + tvm::Array args; + + TypeCallNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("func", &func); + v->Visit("args", &args); + } + + Type eval() const; + + TVM_DLL static TypeCall make(Type func, tvm::Array args); + + static constexpr const char* _type_key = "relay.TypeCall"; + TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); // The following fields contains advanced typing // Only keep the class name and reserved for future usage. diff --git a/src/relay/type.cc b/src/relay/type.cc index 156207e1b73a..22d37ea05fda 100644 --- a/src/relay/type.cc +++ b/src/relay/type.cc @@ -96,5 +96,24 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TypeFunctionNode(" << node->name << ", " << node->num_args << ")"; }); +TypeCall TypeCallNode::make(Type func, Array args) { + std::shared_ptr n = std::make_shared(); + n->func = std::move(func); + n->args = std::move(args); + return TypeCall(n); +} + +TVM_REGISTER_API("relay._make.TypeCall") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TypeCallNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TypeCallNode *node, + tvm::IRPrinter *p) { + p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; + }); + + } // namespace relay } // namespace tvm From 1ddfe128c04b0f43907b312fa2ced043039f0173 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 13:23:13 -0700 Subject: [PATCH 14/88] Add test for let with IR builder --- python/tvm/relay/ir_builder.py | 104 ++++++++++++++++++++++++++ tests/python/relay/test_ir_builder.py | 23 ++++++ 2 files changed, 127 insertions(+) create mode 100644 python/tvm/relay/ir_builder.py create mode 100644 tests/python/relay/test_ir_builder.py diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py new file mode 100644 index 000000000000..497479140ec9 --- /dev/null +++ b/python/tvm/relay/ir_builder.py @@ -0,0 +1,104 @@ +from typing import Any +import numpy as np +import tvm +from . import type as ty +from . import expr +from . import make as mk + + +def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: + """Convert Python values into the appropriate types + for the Relay evaluator. + """ + if isinstance(arg, int): + return tvm.nd.array(arg, ctxt) + elif isinstance(arg, float): + return tvm.nd.array(arg, ctxt) + elif isinstance(arg, bool): + return tvm.nd.array(arg, ctxt) + elif isinstance(arg, np.ndarray): + return tvm.nd.array(arg, ctxt) + elif isinstance(arg, tvm.ndarray.NDArray): + return arg + else: + # raise Exception(f"can't convert {type(arg)} to a Relay AST") + raise Exception(f"unsupported argument type {type(arg)}") + +def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> expr.Expr: + if isinstance(arg, tuple): + raise Exception("..") + else: + value = convert(arg, ctxt) + return mk.Constant(value) + +class WithScope(object): + """Auxiliary scope with""" + + def __init__(self, enter_value, exit_cb): + self._enter_value = enter_value + self._exit_cb = exit_cb + + def __enter__(self): + return self._enter_value + + def __exit__(self, ptype, value, trace): + self._exit_cb() + +def _mk_let(bindings, ret_value): + let_expr = ret_value + for var, value in reversed(list(bindings.items())): + let_expr = mk.Let(var, value, let_expr, None) + + return let_expr + +class IRBuilder(): + def __init__(self): + self.bindings = [{}] + self.scopes = [{}] + self.ret_value = None + + def bind(self, name, type, value): + lv = mk.LocalVar(name) + self.scopes[-1][name] = lv + self.bindings[-1][lv] = value + return lv + + + def let(self, name, value): + if not isinstance(value, expr.Expr): + value = into_ast(value) + + return self.bind(name, None, value) + + def function(self, params): + def _on_exit(): + bindings = self.bindings.pop() + scope = self.scopes.pop() + import pdb + pdb.set_trace() + return WithScope(None, _on_exit) + + def ret(self, x): + if not self.ret_value: + self.ret_value = x + else: + raise Exception( + "return value already set, a function can only have one return value") + + def get(self): + """Get the full program""" + bindings = self.bindings.pop() + scope = self.scopes.pop() + + if self.bindings: + raise Exception("...") + if self.scopes: + raise Exception("...") + + if not self.ret_value: + raise Exception("...") + + return _mk_let(bindings, self.ret_value) + + + diff --git a/tests/python/relay/test_ir_builder.py b/tests/python/relay/test_ir_builder.py new file mode 100644 index 000000000000..666d7ff25659 --- /dev/null +++ b/tests/python/relay/test_ir_builder.py @@ -0,0 +1,23 @@ +import numpy as np +from tvm.relay.expr import Let, Constant +from tvm.relay.ir_builder import IRBuilder + +def test_let(): + b = IRBuilder() + x = b.let('x', 1) + b.ret(x) + prog = b.get() + assert isinstance(prog, Let) + var = prog.var + value = prog.value + assert var.name_hint == 'x' + assert var == prog.body + assert isinstance(value, Constant) + assert value.data.asnumpy() == np.array(1) + assert prog.value_type == None + +# def test_function(): +# b = IRBuilder() + +if __name__ == "__main__": + test_let() From 3ecbc2bbe268c7e175e56cf9e9a65b1715f80641 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 13:24:43 -0700 Subject: [PATCH 15/88] Add initial version of unifier and old tests --- src/relay/compiler/unifier.cc | 477 ++++++++++++++++++++++++++++ src/relay/compiler/unifier.h | 129 ++++++++ tests/python/relay/test_unifier.py | 480 +++++++++++++++++++++++++++++ 3 files changed, 1086 insertions(+) create mode 100644 src/relay/compiler/unifier.cc create mode 100644 src/relay/compiler/unifier.h create mode 100644 tests/python/relay/test_unifier.py diff --git a/src/relay/compiler/unifier.cc b/src/relay/compiler/unifier.cc new file mode 100644 index 000000000000..bfd3e1a5ff32 --- /dev/null +++ b/src/relay/compiler/unifier.cc @@ -0,0 +1,477 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file unifier.cc + * \brief Data structures for type unification + */ + +#include "tvm/relay/ir.h" +#include "tvm/relay/logging.h" +#include "tvm/relay/compiler/alpha_eq.h" +#include "./unifier.h" +#include "./type_visitor.h" +// #include "tvm/relay/typeck/kindchecker.h" +// #include "tvm/relay/typeck/type_subst.h" + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +UnionFind UnionFindNode::make(tvm::Map uf_map) { + std::shared_ptr n = std::make_shared(); + n->uf_map = uf_map; + return UnionFind(n); +} + +void UnionFindNode::insert(const IncompleteType &v) { this->uf_map.Set(v, v); } + +void UnionFindNode::debug() { + for (auto entry : this->uf_map) { + std::cout << entry.first << " = " << entry.second << std::endl; + } +} + +void UnionFindNode::assertAlphaEq(const Type & l, const Type & r) { + if (!alpha_eq(l, r)) { + std::stringstream ss; + ss << "Incompatible parent types in UF:" << l << " and " << r; + throw UnionFindError(ss.str()); + } +} + +void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { + RELAY_LOG(INFO) << "UnionFindNode::Unify v1=" << v1 << "t=" << t << std::endl; + auto parent1 = this->find(v1); + + // if t is a type var, then unify parents + const IncompleteTypeNode *tvn2 = t.as(); + if (tvn2) { + auto v2 = GetRef(tvn2); + auto parent2 = this->find(v2); + + // if parents are exactly equal, then we're done + if (parent1 == parent2) { + return; + } + + // if first parent is a type var, then can just set its union find map to + // second parent + if (const IncompleteTypeNode *pvn1 = parent1.as()) { + auto pv1 = GetRef(pvn1); + this->uf_map.Set(pv1, parent2); + // path compression: can also set v1 directly + this->uf_map.Set(v1, parent2); + return; + } + + // if second parent is a type var but first isn't, can set second type var + if (const IncompleteTypeNode *pvn2 = parent2.as()) { + auto pv2 = GetRef(pvn2); + this->uf_map.Set(pv2, parent1); + // path compression: can also set v2 directly + this->uf_map.Set(v2, parent1); + return; + } + + // if both parents are not type vars themselves, check alpha-equality + assertAlphaEq(parent1, parent2); + return; + } + + // if t is not a type var, then unify with v1's parent if parent is a type + // var; else, check alpha-equality for compatibility + if (const IncompleteTypeNode *pvn1 = parent1.as()) { + auto pv1 = GetRef(pvn1); + this->uf_map.Set(pv1, t); + // path compression: can also set v1 directly + this->uf_map.Set(v1, t); + return; + } + + assertAlphaEq(parent1, t); +} + +Type UnionFindNode::find(const IncompleteType &v) { + // The node has no mapping, so its representative is just itself. + if (this->uf_map.find(v) == this->uf_map.end()) { + return v; + } + + Type parent = this->uf_map.at(v); + + if (v == parent) { + return v; + } + + // if parent is not a type var, then it must be the representative type + const IncompleteTypeNode *rep = parent.as(); + if (!rep) { + return parent; + } + + // otherwise, recurse and perform path compression + IncompleteType pv = GetRef(rep); + Type higher_up = this->find(pv); + this->uf_map.Set(v, higher_up); + return higher_up; +} + +TVM_REGISTER_API("relay._make.UnionFind") + .set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() == 0) { + *ret = UnionFindNode::make({}); + } else { + *ret = UnionFindNode::make(args[0]); + } + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const UnionFindNode *node, + tvm::IRPrinter *p) { + p->stream << "UnionFindNode(" << node->uf_map << ")"; + }); + +TypeUnifier TypeUnifierNode::make(UnionFind uf) { + std::shared_ptr n = std::make_shared(); + n->uf = uf; + return TypeUnifier(n); +} + +void TypeUnifierNode::insert(const IncompleteType &v) { this->uf->insert(v); } + +Type TypeUnifierNode::unify(const Type &t1, const Type &t2) { + RELAY_LOG(INFO) << "TypeUnifierNode::unify: t1=" << t1 << " t2=" << t2 + << std::endl; + + Type unified = this->VisitType(t1, t2); + // if (!check_kind(unified)) { + // throw UnificationError("Invalid kinds in unified type"); + // } + return unified; +} + +struct IncompleteTypeSubst : TypeFVisitor { + const TypeUnifierNode *unifier; + + IncompleteTypeSubst(const TypeUnifierNode *unifier) : unifier(unifier) {} + + // type var: look it up in the type map and recurse + Type VisitType_(const IncompleteTypeNode *op) override { + auto tv = GetRef(op); + auto parent = unifier->uf->find(tv); + if (parent == tv) { + return tv; + } + return this->VisitType(parent); + } +}; + +Type TypeUnifierNode::subst(const Type &t) { + IncompleteTypeSubst tvsubst(this); + // normalize first so substitutions in quantifiers will be correct + Type ret = tvsubst.VisitType(t); + // if (!check_kind(ret)) { + // std::stringstream ss; + // ss << "Invalid Kinds in substituted type!"; + // ss << t << std::endl; + // ss << ret << std::endl; + // throw SubstitutionError(ss.str()); + // } + return ret; +} + +Type TypeUnifierNode::VisitType_(const IncompleteTypeNode *t1, const Type rt2) { + IncompleteType tv1 = GetRef(t1); + RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode t1=" << t1 << " = " << rt2 + << std::endl; + this->uf->unify(tv1, rt2); + auto rep = this->uf->find(tv1); + RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode rep=" << rep << std::endl; + return rep; +} + +Type TypeUnifierNode::VisitType_(const TypeParamNode *t1, const Type rt2) { + TypeParam ti1 = GetRef(t1); + + // for typevars, remap and attempt to unify if already defined + if (const IncompleteTypeNode *tvn2 = rt2.as()) { + return this->unifyWithIncompleteType(ti1, GetRef(tvn2)); + } + + // for other type ids, only check equality + if (const TypeParamNode *tin2 = rt2.as()) { + TypeParam ti2 = GetRef(tin2); + + if (ti1 != ti2) { + throw UnificationError("Attempting to unify non-matching TypeParams"); + } + + return ti1; + } + + // cannot unify TypeParam with non-TypeParam + throw UnificationError("Unable to unify TypeParamNode"); +} + +Type TypeUnifierNode::VisitType_(const FuncTypeNode *t1, const Type rt2) { + return rt2; +// TypeArrow ta1 = GetRef(t1); + +// // for typevar, remap if necessary +// if (const IncompleteTypeNode *tvn2 = rt2.as()) { +// return this->unifyWithIncompleteType(ta1, GetRef(tvn2)); +// } + +// // for other arrow, unify arg and ret types +// if (const TypeArrowNode *tan2 = rt2.as()) { +// TypeArrow ta2 = GetRef(tan2); + +// if (ta1->arg_types.size() != ta2->arg_types.size()) { +// throw UnificationError("unable to unify functions of different arities"); +// } + +// tvm::Array unified_args; +// for (size_t i = 0; i < ta1->arg_types.size(); i++) { +// unified_args.push_back( +// this->VisitType(ta1->arg_types[i], ta2->arg_types[i])); +// } + +// Type unified_ret_type = this->VisitType(ta1->ret_type, ta2->ret_type); +// return TypeArrowNode::make(unified_args, unified_ret_type); +// } + +// throw UnificationError("Unable to unify TypeArrowNode"); +// } + +// Type TypeUnifierNode::VisitType_(const TypeQuantifierNode *t1, const Type rt2) { +// TypeQuantifier tq1 = GetRef(t1); + +// // for typevars, remap and attempt to unify if already defined +// if (const IncompleteTypeNode *tvn2 = rt2.as()) { +// return this->unifyWithIncompleteType(tq1, GetRef(tvn2)); +// } + +// // for other quantifiers, attempt to unify bound types after normalizing +// if (const TypeQuantifierNode *tqn2 = rt2.as()) { +// TypeQuantifier tq2 = GetRef(tqn2); +// TypeParam id1 = tq1->id; +// TypeParam id2 = tq2->id; + +// if (id1->kind != id2->kind) { +// throw UnificationError( +// "Cannot unify quantifiers over ids of different kinds"); +// } + +// TypeParam fresh = TypeParamNode::make(id1->name, id1->kind); + +// auto bt1 = type_subst(tq1->boundType, id1, fresh); +// auto bt2 = type_subst(tq2->boundType, id2, fresh); + +// Type unified_bound_type = this->VisitType(bt1, bt2); +// return TypeQuantifierNode::make(fresh, unified_bound_type); +// } + +// // anything else cannot be unified +// throw UnificationError("Cannot unify TypeQuantifierNode"); +} + +Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { + TensorType tt1 = GetRef(t1); + + // for typevars, remap and attempt to unify if already defined + if (const IncompleteTypeNode *tvn2 = rt2.as()) { + return this->unifyWithIncompleteType(tt1, GetRef(tvn2)); + } + + if (const TensorTypeNode *ttn2 = rt2.as()) { + TensorType tt2 = GetRef(ttn2); + + if (!alpha_eq(tt1, tt2)) { + throw UnificationError("dtypes do not match"); + } + + RELAY_LOG(INFO) << "Unify Tensor Shape s1=" << tt1->shape + << " s2= " << tt2->shape << std::endl; + try { + // Type unified_shape = this->VisitType(tt1->shape, tt2->shape); + return rt2; + } catch (const UnificationError & err) { + std::cout << "Need to check constraint " << tt1->shape << " = " << tt2->shape << std::endl; + } + + // fix me + return rt2; + // return TensorTypeNode::make(unified_bt, tt2->shape); + } + + // nothing else can unify + throw UnificationError("Cannot unify TensorTypeNode"); +} + +// Type TypeUnifierNode::VisitType_(const TupleTypeNode *t1, const Type rt2) { +// TupleType pt1 = GetRef(t1); + +// // for typevar, remap and attempt to unify if already defined +// if (const IncompleteTypeNode *tvn2 = rt2.as()) { +// return this->unifyWithIncompleteType(pt1, GetRef(tvn2)); +// } + +// // for other product types, unify item by item +// if (const TupleTypeNode *ptn2 = rt2.as()) { +// TupleType pt2 = GetRef(ptn2); + +// std::vector unified_fields; +// if (pt1->fields.size() != pt2->fields.size()) { +// throw UnificationError("Product types are of different dimensions"); +// } + +// for (size_t i = 0U; i < pt1->fields.size(); i++) { +// Type unified = this->VisitType(pt1->fields[i], pt2->fields[i]); +// unified_fields.push_back(unified); +// } + +// return TupleTypeNode::make(unified_fields); +// } + +// // otherwise cannot unify +// throw UnificationError("Cannot unify TupleTypeNode"); +// } + +Type TypeUnifierNode::VisitType_(const TypeFunctionNode *sen1, const Type t2) { +// ShapeExtension sh_ext1 = GetRef(sen1); + +// if (const IncompleteTypeNode *tvn2 = t2.as()) { +// return this->unifyWithIncompleteType(sh_ext1, GetRef(tvn2)); +// } + +// // will only attempt to unify with binary op with same op +// if (const ShapeExtensionNode *sen2 = t2.as()) { +// if (sh_ext1->name != sen2->name) { +// throw UnificationError( +// "Cannot unify shape projections of different index"); +// } +// } + +// return sh_ext1; + return t2; +} + +Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { + TypeCall ty_call1 = GetRef(tcn1); + + if (const IncompleteTypeNode *tvn2 = t2.as()) { + return this->unifyWithIncompleteType(ty_call1, GetRef(tvn2)); + } + + if (const TypeCallNode *tcn2 = t2.as()) { + Type unified_func = this->VisitType(ty_call1->func, tcn2->func); + + // For now, we will only unify if they are equal. + if (ty_call1->args.size() != tcn2->args.size()) { + throw UnificationError("Cannot unify calls of different number of arguments"); + } + + // Unify members, if possible + tvm::Array new_args; + for (size_t i = 0U; i < ty_call1->args.size(); i++) { + Type unified_member = this->VisitType(ty_call1->args[i], tcn2->args[i]); + new_args.push_back(unified_member); + } + + return TypeCallNode::make(unified_func, new_args); + } else { + throw UnificationError("Cannot unify call with non-call"); + } +} + +Type TypeUnifierNode::unifyWithIncompleteType(const Type &t1, const IncompleteType tv2) { + RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 << std::endl; + // Fix unify to return new representative + this->uf->unify(tv2, t1); + auto rep = this->uf->find(tv2); + RELAY_LOG(INFO) << "unifyWithIncompleteType: rep =" << rep << std::endl; + return rep; +} + +TVM_REGISTER_API("relay._make.TypeUnifier") + .set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() < 3) { + *ret = TypeUnifierNode::make(UnionFindNode::make({})); + } else { + *ret = TypeUnifierNode::make(args[0]); + } + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TypeUnifierNode *node, + tvm::IRPrinter *p) { + p->stream << "TypeUnifierNode(" << node->uf << ")"; + }); + +TVM_REGISTER_API("relay._unifier.UnionFind_insert") + .set_body([](TVMArgs args, TVMRetValue *ret) { + try { + UnionFind uf = args[0]; + uf->insert(args[1]); + } catch (std::exception &e) { + throw UnionFindError(e.what()); + } + }); + +TVM_REGISTER_API("relay._unifier.UnionFind_unify") + .set_body([](TVMArgs args, TVMRetValue *ret) { + try { + UnionFind uf = args[0]; + uf->unify(args[1], args[2]); + } catch (std::exception &e) { + throw UnionFindError(e.what()); + } + }); + +TVM_REGISTER_API("relay._unifier.UnionFind_find") + .set_body([](TVMArgs args, TVMRetValue *ret) { + try { + UnionFind uf = args[0]; + *ret = uf->find(args[1]); + } catch (std::exception &e) { + throw UnionFindError(e.what()); + } + }); + +TVM_REGISTER_API("relay._unifier.TypeUnifier_insert") + .set_body([](TVMArgs args, TVMRetValue *ret) { + try { + TypeUnifier unifier = args[0]; + IncompleteType var = args[1]; + unifier->insert(var); + } catch (std::exception &e) { + throw UnificationError(e.what()); + } + }); + +TVM_REGISTER_API("relay._unifier.TypeUnifier_unify") + .set_body([](TVMArgs args, TVMRetValue *ret) { + try { + TypeUnifier unifier = args[0]; + Type t1 = args[1]; + Type t2 = args[2]; + *ret = unifier->unify(t1, t2); + } catch (std::exception &e) { + throw UnificationError(e.what()); + } + }); + +TVM_REGISTER_API("relay._unifier.TypeUnifier_subst") + .set_body([](TVMArgs args, TVMRetValue *ret) { + try { + TypeUnifier unifier = args[0]; + Type t = args[1]; + *ret = unifier->subst(t); + } catch (std::exception &e) { + throw SubstitutionError(e.what()); + } + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/compiler/unifier.h b/src/relay/compiler/unifier.h new file mode 100644 index 000000000000..6788265c90f2 --- /dev/null +++ b/src/relay/compiler/unifier.h @@ -0,0 +1,129 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file unifier.h + * \brief The type unifier which solves a system of equations between + * incomplete types. + */ +#ifndef TVM_RELAY_COMPILER_UNIFIER_H_ +#define TVM_RELAY_COMPILER_UNIFIER_H_ + +#include +#include "./type_functor.h" +#include "tvm/relay/ir.h" + +namespace tvm { +namespace relay { + +struct UnionFindError : dmlc::Error { + explicit UnionFindError(const std::string& msg) : Error(msg) {} +}; + +struct UnificationError : dmlc::Error { + explicit UnificationError(const std::string& msg) : Error(msg) {} +}; + +struct SubstitutionError : dmlc::Error { + explicit SubstitutionError(const std::string& msg) : Error(msg) {} +}; + +/*! \brief a union-find data structure for the type-checker */ +class UnionFind; // forward declaration + +class UnionFindNode : public Node { + public: + tvm::Map uf_map; + + UnionFindNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("uf_map", &uf_map); } + + TVM_DLL static UnionFind make(tvm::Map uf_map); + + // insert v into UF + void insert(const IncompleteType& v); + + // infers that v1 and v2 must be of the smae type + void unify(const IncompleteType& v1, const Type& v2); + + // returns representative of v's UF-group + Type find(const IncompleteType& v); + + void debug(); + + void assertAlphaEq(const Type& l, const Type& r); + + static constexpr const char* _type_key = "relay.UnionFind"; + TVM_DECLARE_NODE_TYPE_INFO(UnionFindNode, Node); +}; + +class UnionFind : public NodeRef { + public: + UnionFind() {} + explicit UnionFind(std::shared_ptr p) : NodeRef(p) {} + + // no const so that union find can be mutable as a member of unifier + inline UnionFindNode* operator->() const { + return static_cast(node_.get()); + } + + using ContainerType = UnionFindNode; +}; + +class TypeUnifier; +class TypeUnifierNode : public Node, + private TypeFunctor { + public: + UnionFind uf; + + TypeUnifierNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("uf", &uf); } + + TVM_DLL static TypeUnifier make(UnionFind uf); + + /*! \brief Introduces a new type var into the unifier */ + void insert(const IncompleteType& v); + + /*! \brief Unifies two types if possible, throws a unification error if it + * cannot */ + Type unify(const Type& t1, const Type& t2); + + /*! \brief Attempts to substitute all type vars in t with concrete types, + * throws substitution error if it cannot concretize*/ + Type subst(const Type& t); + + // /*! \brief Checks the kinds in the given type */ + // Type CheckKinds(const Type& t); + + static constexpr const char* _type_key = "relay.TypeUnifier"; + TVM_DECLARE_NODE_TYPE_INFO(TypeUnifierNode, Node); + + private: + // unify non-typevar with typevar + Type unifyWithIncompleteType(const Type& t1, const IncompleteType tvn2); + + Type VisitType_(const IncompleteTypeNode* t1, const Type t2) override; + Type VisitType_(const TensorTypeNode* t1, const Type t2) override; + Type VisitType_(const TypeParamNode* t1, const Type t2) override; + Type VisitType_(const FuncTypeNode* t1, const Type t2) override; + // Type VisitType_(const TupleTypeNode* t1, const Type t2) override; + Type VisitType_(const TypeFunctionNode* s1, const Type t2) override; + Type VisitType_(const TypeCallNode* s1, const Type t2) override; +}; + +class TypeUnifier : public NodeRef { + public: + TypeUnifier() {} + explicit TypeUnifier(std::shared_ptr p) : NodeRef(p) {} + + // no const so that unifier can be mutable as a member of typechecker + inline TypeUnifierNode* operator->() const { + return static_cast(node_.get()); + } + + using ContainerType = TypeUnifierNode; +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_TYPECK_UNIFIER_H_ diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py new file mode 100644 index 000000000000..7938a5a3ae5f --- /dev/null +++ b/tests/python/relay/test_unifier.py @@ -0,0 +1,480 @@ +"""Tests unification of types.""" +# pylint: disable=invalid-name, missing-docstring, bare-except +import relay.ir +# pylint: disable=unused-import +import relay.unifier # TODO (@jroesch) fix me +# pylint: disable=wildcard-import, unused-wildcard-import +from relay.make import * + +def unify_types(t1, t2): + unifier = TypeUnifier() + return unifier.unify(t1, t2) + +def int_type(): + return TensorType(IntType(32), ShapeSeq([])) + +def float_type(): + return TensorType(FloatType(32), ShapeSeq([])) + +def bool_type(): + return TensorType(BoolType(), ShapeSeq([])) + +def make_shape(dims): + return ShapeSeq([ShapeSingleton(dim) for dim in dims]) + +def test_insert_and_find(): + uf = UnionFind() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + uf.insert(v1) + uf.insert(v2) + assert uf.find(v1) == v1 + assert uf.find(v2) == v2 + +def test_insert_error(): + uf = UnionFind() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + uf.insert(v1) + try: + uf.find(v2) + assert False + except: + return + +def test_unify(): + uf = UnionFind() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + v3 = TypeVar(ir.Kind.Type) + uf.insert(v1) + uf.insert(v2) + uf.insert(v3) + uf.unify(v1, v2) + rep = uf.find(v1) + assert (rep == v1 or rep == v2) + assert uf.find(v1) == rep + assert uf.find(v2) == rep + assert uf.find(v3) == v3 + assert v3 != rep + uf.unify(v1, v3) + new_rep = uf.find(v3) + assert (rep == v1 or rep == v2 or rep == v3) + assert uf.find(v1) == new_rep + assert uf.find(v2) == new_rep + assert uf.find(v3) == new_rep + +def test_unify_multiple_levels(): + uf = UnionFind() + v = [TypeVar(ir.Kind.Type) for _ in range(9)] + for var in v: + uf.insert(var) + uf.unify(v[0], v[1]) + uf.unify(v[0], v[2]) + uf.unify(v[3], v[4]) + uf.unify(v[4], v[5]) + uf.unify(v[6], v[7]) + uf.unify(v[6], v[8]) + rep1 = uf.find(v[0]) + rep2 = uf.find(v[3]) + rep3 = uf.find(v[6]) + assert (rep1 == v[0] or rep1 == v[1] or rep1 == v[2]) + assert (rep2 == v[3] or rep2 == v[4] or rep2 == v[5]) + assert (rep3 == v[6] or rep3 == v[7] or rep3 == v[8]) + for i in range(3): + assert uf.find(v[i]) == rep1 + assert uf.find(v[i + 3]) == rep2 + assert uf.find(v[i + 6]) == rep3 + # now unify two of the groups + uf.unify(v[1], v[4]) + new_rep1 = uf.find(v[0]) + new_rep2 = uf.find(v[6]) + assert (new_rep1 == v[0] or new_rep1 == v[1] or new_rep1 == v[2] + or new_rep1 == v[3] or new_rep1 == v[4] or new_rep1 == v[5]) + assert (new_rep2 == v[6] or new_rep2 == v[7] or new_rep2 == v[8]) + for i in range(6): + assert uf.find(v[i]) == new_rep1 + for i in range(3): + assert uf.find(v[i + 6]) == new_rep2 + +# TODO(sslyu, weberlo, joshpoll): put in isinstance asserts once those work +def test_unify_int(): + intty = IntType(1) + unified = unify_types(intty, intty) + assert intty == unified + +def test_unify_bool(): + boolty = BoolType() + unified = unify_types(boolty, boolty) + assert boolty == unified + +def test_unify_float(): + floatty = FloatType(4) + unified = unify_types(floatty, floatty) + assert floatty == unified + +def test_unify_incompatible_basetypes(): + bt = BoolType() + intty = IntType(32) + try: + unify_types(bt, intty) + assert False + except: + return + +def test_unify_concrete_type_arrow(): + arr1 = TypeArrow([int_type()], int_type()) + arr2 = TypeArrow([int_type()], int_type()) + unified = unify_types(arr1, arr2) + assert unified == arr1 + +def test_unify_type_arrow_with_holes(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + unifier.insert(v1) + unifier.unify(v1, bool_type()) + arr1 = TypeArrow([int_type()], bool_type()) + arr2 = TypeArrow([int_type()], v1) + unified = unifier.unify(arr1, arr2) + assert unified == arr1 + + v2 = TypeVar(ir.Kind.BaseType) + unifier.insert(v2) + unifier.unify(v2, int_type()) + arr3 = TypeArrow([v2], bool_type()) + unified = unifier.unify(arr1, arr3) + assert unified == arr1 + +def test_reject_incompatible_type_arrows(): + arr1 = TypeArrow([int_type()], bool_type()) + arr2 = TypeArrow([int_type(), bool_type()], bool_type()) + try: + unify_types(arr1, arr2) + assert False + except: + return + +def test_unify_concrete_type_quantifiers(): + tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) + tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), int_type()) + unified = unify_types(tq1, tq2) + assert unified == tq1 + +def test_unify_basetype_with_quantifier_error(): + bt = bool_type() + tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) + try: + unify_types(bt, tq) + assert False + except: + return + +def test_unify_typevars_with_each_other(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + v3 = TypeVar(ir.Kind.Type) + unifier.insert(v1) + unifier.insert(v2) + unifier.insert(v3) + unified = unifier.unify(v1, v2) + assert (unified == v1 or unified == v2) + assert unified != v3 + new_unified = unifier.unify(v1, v3) + assert (new_unified == v1 or new_unified == v2 or new_unified == v3) + +def test_unify_typevars_with_basetype(): + unifier = TypeUnifier() + bt = BoolType() + v1 = TypeVar(ir.Kind.BaseType) + v2 = TypeVar(ir.Kind.BaseType) + unifier.insert(v1) + unifier.insert(v2) + unified1 = unifier.unify(v1, bt) + assert unified1 == bt + unified2 = unifier.unify(v1, v2) + assert unified2 == bt + +def test_unify_compatible_typevars(): + unifier = TypeUnifier() + bt = BoolType() + v1 = TypeVar(ir.Kind.BaseType) + v2 = TypeVar(ir.Kind.BaseType) + unifier.insert(v1) + unifier.insert(v2) + unifier.unify(v1, bt) + unifier.unify(v2, bt) + # because types to which v1 and v2 have been assigned are compatible, + # this should proceed without problems + unified = unifier.unify(v1, v2) + assert unified == bt + +def test_unify_incompatible_typevars(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + v2 = TypeVar(ir.Kind.BaseType) + bt = bool_type() + tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) + unifier.insert(v1) + unifier.insert(v2) + unifier.unify(v1, bt) + unifier.unify(v2, tq) + # bt cannot be unified with tq, so unifying v1 and v2 should give an error + try: + unifier.unify(v1, v2) + assert False + except: + return + +def test_unify_typevar_with_quantifier(): + unifier = TypeUnifier() + tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bool_type()) + v1 = TypeVar(ir.Kind.BaseType) + unifier.insert(v1) + unified = unifier.unify(v1, tq) + assert unified == tq + +def test_unify_typevars_inside_concrete_quantifier(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + unifier.insert(v1) + tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v1) + tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), bool_type()) + unified = unifier.unify(tq1, tq2) + assert unified == tq2 + +def test_unify_concrete_tensors(): + bt = BoolType() + shape = make_shape([1, 2, 3]) + tt1 = TensorType(bt, shape) + tt2 = TensorType(bt, shape) + unified = unify_types(tt1, tt2) + assert unified == tt1 + +def test_unify_tensor_shape_reject(): + bt = BoolType() + shape1 = make_shape([1, 2, 3]) + shape2 = make_shape([2, 3, 4]) + tt1 = TensorType(bt, shape1) + tt2 = TensorType(bt, shape2) + try: + unify_types(tt1, tt2) + assert False + except: + return + +def test_unify_tensor_dtype_reject(): + bt1 = BoolType() + bt2 = IntType(32) + shape = make_shape([1, 2, 3]) + tt1 = TensorType(bt1, shape) + tt2 = TensorType(bt2, shape) + try: + unify_types(tt1, tt2) + assert False + except: + return + +def test_unify_quantified_tensors(): + x = TypeParam("x", ir.type.Kind.Shape) + y = TypeParam("y", ir.type.Kind.Shape) + tq1 = TypeQuantifier(x, TensorType(BoolType(), x)) + tq2 = TypeQuantifier(y, TensorType(BoolType(), y)) + unified = unify_types(tq1, tq2) + assert unified == tq1 + + a = TypeParam("a", ir.type.Kind.BaseType) + b = TypeParam("b", ir.type.Kind.BaseType) + tq3 = TypeQuantifier(a, TensorType(a, make_shape([1, 2, 3]))) + tq4 = TypeQuantifier(b, TensorType(b, make_shape([1, 2, 3]))) + unified = unify_types(tq3, tq4) + assert unified == tq3 + +def test_unify_concrete_products(): + bt = bool_type() + intty = int_type() + pt1 = TupleType([bt, intty]) + pt2 = TupleType([bt, intty]) + unified = unify_types(pt1, pt2) + assert unified == pt1 + +def test_unify_products_reject_size(): + bt = BoolType() + intty = IntType(32) + pt1 = TupleType([bt, bt, intty]) + pt2 = TupleType([bt, intty]) + try: + unify_types(pt1, pt2) + assert False + except: + return + +def test_unify_products_reject_member(): + bt = BoolType() + intty = IntType(32) + pt1 = TupleType([bt, bt]) + pt2 = TupleType([bt, intty]) + try: + unify_types(pt1, pt2) + assert False + except: + return + +def test_unify_products_typevar(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + bt = bool_type() + pt1 = TupleType([bt, bt]) + pt2 = TupleType([v1, bt]) + unifier.insert(v1) + unified = unifier.unify(pt1, pt2) + assert unified == pt1 + +def test_unify_quantified_products(): + x = TypeParam("x", ir.Kind.Type) + y = TypeParam("y", ir.Kind.Type) + p1 = TypeQuantifier(x, TupleType([int_type(), x])) + p2 = TypeQuantifier(y, TupleType([int_type(), y])) + unified = unify_types(p1, p2) + assert unified == p1 + +def test_unify_ref_types(): + r1 = RefType(bool_type()) + r2 = RefType(bool_type()) + assert unify_types(r1, r2) == r1 + +def test_unify_ref_reject_inner(): + r1 = RefType(BoolType()) + r2 = RefType(IntType(32)) + try: + unify_types(r1, r2) + assert False + except: + return + +def test_subst_basetype(): + unifier = TypeUnifier() + bt = BoolType() + assert bt == unifier.subst(bt) + +def test_subst_simple_hole(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + bt = BoolType() + unifier.insert(v1) + unifier.unify(v1, bt) + assert unifier.subst(v1) == bt + +def test_subst_typevar_for_typevar(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + unifier.insert(v1) + unifier.insert(v2) + + unifier.unify(v1, v2) + assert unifier.subst(v1) == v2 + +def test_subst_concrete_arrow(): + unifier = TypeUnifier() + arr1 = TypeArrow([int_type()], int_type()) + assert unifier.subst(arr1) == arr1 + +def test_subst_arrow_with_holes(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + v2 = TypeVar(ir.Kind.BaseType) + unifier.insert(v1) + unifier.insert(v2) + unifier.unify(v1, int_type()) + unifier.unify(v2, bool_type()) + arr1 = TypeArrow([v1], v2) + arr2 = TypeArrow([int_type()], bool_type()) + assert unifier.subst(arr1) == arr2 + +def test_subst_concrete_quantifier(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) + unifier.insert(v1) + unifier.unify(v1, tq) + assert unifier.subst(v1) == tq + +def test_subst_quantifier_with_holes(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v2) + intty = int_type() + tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), intty) + + unifier.insert(v1) + unifier.insert(v2) + unifier.unify(v2, intty) + unifier.unify(v1, tq1) + assert unifier.subst(v1) == tq2 + +def test_subst_concrete_tensor(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + unifier.insert(v1) + tt = TensorType(BoolType(), make_shape([1, 2, 3])) + unifier.unify(v1, tt) + assert unifier.subst(v1) == tt + +def test_subst_concrete_product(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + unifier.insert(v1) + bt = bool_type() + pt = TupleType([bt, bt]) + unifier.unify(v1, pt) + assert unifier.subst(v1) == pt + +def test_subst_product_with_holes(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + v3 = TypeVar(ir.Kind.Type) + unifier.insert(v1) + unifier.insert(v2) + unifier.insert(v3) + + tt1 = TensorType(IntType(32), ShapeSeq([])) + tt2 = TensorType(FloatType(32), ShapeSeq([])) + pt1 = TupleType([tt1, v2, v3]) + unifier.unify(v2, tt2) + unifier.unify(v3, v2) + unifier.unify(v1, pt1) + pt2 = TupleType([tt1, tt2, tt2]) + assert unifier.subst(v1) == pt2 + +def test_subst_concrete_ref(): + unifier = TypeUnifier() + rt = RefType(bool_type()) + assert unifier.subst(rt) == rt + +def test_subst_ref_with_hole(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + unifier.insert(v1) + + unifier.unify(v1, bool_type()) + rt1 = RefType(v1) + rt2 = RefType(bool_type()) + assert unifier.subst(rt1) == rt2 + +def test_typevar_on_lhs(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + v2 = TypeVar(ir.Kind.Type) + bt = bool_type() + tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt, bt) + unifier.insert(v1) + unifier.insert(v2) + unified1 = unifier.unify(bt, v1) + assert unified1 == bt + unified2 = unifier.unify(tq, v2) + assert unified2 == tq + assert unifier.subst(v1) == bt + assert unifier.subst(v2) == tq From 73d457024e0045daca979b03f1d69c4018876411 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 13:25:20 -0700 Subject: [PATCH 16/88] Update type_functor.h for incomplete type. --- include/tvm/relay/compiler/typechecker.h | 2 +- src/relay/compiler/type_functor.h | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/tvm/relay/compiler/typechecker.h b/include/tvm/relay/compiler/typechecker.h index c71f78c1a5b0..c69aba3c1e71 100644 --- a/include/tvm/relay/compiler/typechecker.h +++ b/include/tvm/relay/compiler/typechecker.h @@ -8,7 +8,7 @@ #define TVM_RELAY_COMPILER_TYPECHECKER_H_ #include "tvm/relay/ir.h" -#include "tvm/relay/environment.h" +#include "tvm/relay/compiler/environment.h" namespace tvm { namespace relay { diff --git a/src/relay/compiler/type_functor.h b/src/relay/compiler/type_functor.h index 66454725db48..3840c902bfe8 100644 --- a/src/relay/compiler/type_functor.h +++ b/src/relay/compiler/type_functor.h @@ -7,7 +7,8 @@ #define TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ #include -#include "ir.h" +#include "tvm/relay/ir.h" +#include "./incomplete_type.h" namespace tvm { namespace relay { @@ -61,12 +62,10 @@ class TypeFunctor { Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeParamNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TypeFunction* op, Args... args) TYPE_FUNCTOR_DEFAULT; - Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeFunctionNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitTypeDefault_(const Node* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->type_key(); @@ -84,6 +83,7 @@ class TypeFunctor { RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeFunctionNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode); + RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); return vtable; } }; From 7dc6a24c0d8b7658b351a2f89ad81c96d79ea3a4 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 13:40:34 -0700 Subject: [PATCH 17/88] Add Python side of unifier --- python/tvm/relay/_unifier.py | 5 +++ python/tvm/relay/_unifier.pyi | 12 ++++++ python/tvm/relay/ir.py | 18 +++++++++ python/tvm/relay/type.py | 5 +++ python/tvm/relay/unifier.py | 61 ++++++++++++++++++++++++++++++ tests/python/relay/test_unifier.py | 12 +++--- 6 files changed, 107 insertions(+), 6 deletions(-) create mode 100644 python/tvm/relay/_unifier.py create mode 100644 python/tvm/relay/_unifier.pyi create mode 100644 python/tvm/relay/ir.py create mode 100644 python/tvm/relay/unifier.py diff --git a/python/tvm/relay/_unifier.py b/python/tvm/relay/_unifier.py new file mode 100644 index 000000000000..41f5fe374b3e --- /dev/null +++ b/python/tvm/relay/_unifier.py @@ -0,0 +1,5 @@ +"""FFI functions for the Unifier.""" + +from tvm._ffi.function import _init_api + +_init_api("relay._unifier", __name__) diff --git a/python/tvm/relay/_unifier.pyi b/python/tvm/relay/_unifier.pyi new file mode 100644 index 000000000000..6ecd309250a6 --- /dev/null +++ b/python/tvm/relay/_unifier.pyi @@ -0,0 +1,12 @@ +from tvm.relay.ir import NodeBase + +class UnionFind(NodeBase): ... +class TypeUnifier(NodeBase): ... + +def UnionFind_insert(self: UnionFind, var: ir.IncompleteType) -> None: ... +def UnionFind_unify(self: UnionFind, var1: ir.IncompleteType, var2: ir.IncompleteType) -> None: ... +def UnionFind_find(self: UnionFind, var: ir.IncompleteType) -> ir.Type: ... + +def TypeUnifier_insert(self: TypeUnifier, var: ir.IncompleteType) -> None: ... +def TypeUnifier_unify(self, type1: ir.Type, type2: ir.Type) -> ir.Type: ... +def TypeUnifier_subst(self, type1: ir.Type) -> ir.Type: ... diff --git a/python/tvm/relay/ir.py b/python/tvm/relay/ir.py new file mode 100644 index 000000000000..a95f29abe6de --- /dev/null +++ b/python/tvm/relay/ir.py @@ -0,0 +1,18 @@ +from . import base +from . import type as ty +from . import expr + +# Base +register_relay_node = base.register_relay_node +NodeBase = base.NodeBase + +# Type +Type = ty.Type +TensorType = ty.Type +Kind = ty.Kind +TypeParam = ty.TypeParam +TypeConstraint = ty.TypeConstraint +FuncType = ty.FuncType +IncompleteType = ty.IncompleteType + +# Expr diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index c92f0d756587..4d53cf88a218 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -49,3 +49,8 @@ class FuncType(Type): arg_types: List[Type] ret_type: Type span: Span + +@register_relay_node +class IncompleteType(Type): + """An incomplete type.""" + pass diff --git a/python/tvm/relay/unifier.py b/python/tvm/relay/unifier.py new file mode 100644 index 000000000000..cb818de19c1d --- /dev/null +++ b/python/tvm/relay/unifier.py @@ -0,0 +1,61 @@ +"""The Python interface to Relay's UnionFind and TypeUnifier.""" + +from typing import Dict +from .ir import register_relay_node, NodeBase +from . import ir +from . import _unifier + +@register_relay_node +class UnionFind(NodeBase): + """Python API for UnionFind. + + The UnionFind maintains equality classes of type variables, the + representative of an equality class may be a type (which can) + contain type variables. The TypeUnifier uses this to build a + unification procedure between types. + """ + uf_map: Dict[ir.IncompleteType, ir.IncompleteType] + + def insert(self, var: ir.IncompleteType) -> None: + """Insert a type variable into the union find. + + :param: var: The variable to be inserted. + """ + return _unifier.UnionFind_insert(self, var) + + def unify(self, var: ir.IncompleteType, typ: ir.Type) -> None: + """Unify a type variable with an arbitrary type. + + :param: var: A type variable to be unified. + :param: typ: The type to be unified with. + """ + return _unifier.UnionFind_unify(self, var, typ) + + def find(self, var: ir.IncompleteType) -> ir.IncompleteType: + """Find the representative element of the type var. + + :param: var: The variable to lookup in the union find. + """ + return _unifier.UnionFind_find(self, var) + +@register_relay_node +class TypeUnifier(NodeBase): + """Python API for the TypeUnifier.""" + #pylint: disable=invalid-name + uf: UnionFind + eq_map: Dict[ir.TypeParam, ir.TypeParam] + + def insert(self, var: ir.IncompleteType) -> None: + return _unifier.TypeUnifier_insert(self, var) + + def unify(self, type1: ir.Type, type2: ir.Type) -> ir.Type: + """Unify two types producing the unified type as a result. + + :param: type1: The first type to be unified. + :param: type2: The second type to be unified. + :returns: The unified type. + """ + return _unifier.TypeUnifier_unify(self, type1, type2) + + def subst(self, type1: ir.Type) -> ir.Type: + return _unifier.TypeUnifier_subst(self, type1) diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py index 7938a5a3ae5f..875502808563 100644 --- a/tests/python/relay/test_unifier.py +++ b/tests/python/relay/test_unifier.py @@ -1,10 +1,10 @@ -"""Tests unification of types.""" -# pylint: disable=invalid-name, missing-docstring, bare-except +""" +Test the type unifier, which solves systems of equations +between incomplete types. +""" import relay.ir -# pylint: disable=unused-import -import relay.unifier # TODO (@jroesch) fix me -# pylint: disable=wildcard-import, unused-wildcard-import -from relay.make import * +import relay.unifier + def unify_types(t1, t2): unifier = TypeUnifier() From 5ce60e6b3160aed2b6a37b9ac42e8e5c4e59bed0 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 14:06:09 -0700 Subject: [PATCH 18/88] Add to incomplete_type and add impl in typechecker.cc --- include/tvm/relay/type.h | 5 +- python/tvm/relay/type.py | 4 +- src/relay/compiler/incomplete_type.h | 8 +- src/relay/compiler/typechecker.cc | 771 +++++++++++++++++++++++++++ 4 files changed, 783 insertions(+), 5 deletions(-) create mode 100644 src/relay/compiler/typechecker.cc diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index dfe4309b7c77..4eeb42168d68 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -113,7 +113,10 @@ class TypeParamNode : public TypeNode { /*! \brief possible kinds of TypeParam */ enum Kind : int { /*! \brief template variable in shape expression */ - kShapeVar = 0 + kShapeVar = 0, + kShape = 1, + kBaseType = 2, + kType = 3, }; /*! * \brief The variable diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index 4d53cf88a218..2790b546cfe5 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -21,10 +21,10 @@ class Kind(IntEnum): """The kind of a type parameter, represents a variable shape, base type, type, or dimension. """ - Shape = 0 + ShapeVar = 0 + Shape = 1 BaseType = 1 Type = 2 - Elem = 3 @register_relay_node class TypeParam(Type): diff --git a/src/relay/compiler/incomplete_type.h b/src/relay/compiler/incomplete_type.h index 8f360d1cd51c..f31a2efdf78d 100644 --- a/src/relay/compiler/incomplete_type.h +++ b/src/relay/compiler/incomplete_type.h @@ -20,9 +20,13 @@ class IncompleteType; /*! \brief IncompleteType container node */ class IncompleteTypeNode : public TypeNode { public: - void VisitAttrs(tvm::AttrVisitor* v) final {} + TypeParamNode::Kind kind; - TVM_DLL static IncompleteType make(); + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("kind", &kind); + } + + TVM_DLL static IncompleteType make(TypeParamNode::Kind kind); static constexpr const char* _type_key = "relay.IncompleteType"; TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode); diff --git a/src/relay/compiler/typechecker.cc b/src/relay/compiler/typechecker.cc new file mode 100644 index 000000000000..c1f7b7f88765 --- /dev/null +++ b/src/relay/compiler/typechecker.cc @@ -0,0 +1,771 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file typechecker.cc + * \brief Relay typechecker + */ + +#include "tvm/relay/compiler/typechecker.h" +#include "./incomplete_type.h" +// #include "tvm/relay/alpha_eq.h" +// #include "tvm/relay/debug.h" +// #include "tvm/relay/first_order_reverse_ad.h" +// #include "tvm/relay/free_type_vars.h" +// #include "tvm/relay/gen_fresh.h" +// #include "tvm/relay/ir.h" +// #include "tvm/relay/logging.h" +// #include "tvm/relay/pretty_printer.h" +// #include "tvm/relay/reverse_ad.h" +// #include "tvm/relay/type_visitor.h" +// #include "tvm/relay/typeck/kindchecker.h" +// #include "tvm/relay/typeck/resolve.h" +// #include "tvm/relay/typeck/shape_evaluator.h" + +namespace tvm { +namespace relay { + +// using namespace tvm::runtime; + +// struct FatalTypeError : dmlc::Error { +// explicit FatalTypeError(const std::string & s) : dmlc::Error(s) {} +// }; + +// struct TypeContext { +// std::vector> stack; +// TypeContext() { +// stack.push_back({}); +// } +// void insert(const LocalId &id, const Type &t) { stack.back()[id] = t; } +// Type lookup(const LocalId &id) { +// for (auto frame = stack.rbegin(); frame != stack.rend(); ++frame) { +// if (frame->find(id) != frame->end()) { +// return frame->at(id); +// } +// } +// throw FatalTypeError("Could not resolve local id"); +// } +// struct LocalFrame { +// TypeContext & tc; +// explicit LocalFrame(TypeContext & tc) : tc(tc) { +// tc.stack.push_back({}); +// } +// ~LocalFrame() { +// tc.stack.pop_back(); +// } +// }; +// }; + +// class Typechecker : private ExprFunctor { +// private: +// TypeContext local_stack; +// public: +// Environment env; +// TypeUnifier unifier; + +// template +// T with_frame(const std::function & f) { +// TypeContext::LocalFrame fr(local_stack); +// return f(); +// } + +// Typechecker(); +// Typechecker(Environment env, TypeUnifier unifier) : env(env), unifier(unifier) {} +// explicit Typechecker(Environment env); +// Type Check(const Expr & expr); +// Type instantiate(Type t, tvm::Array & ty_args); + +// void report_error(const std::string & msg, Span sp); +// [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); + +// Type unify(const Type &t1, const Type &t2, Span sp); +// Type resolve(const Type &t); +// Expr resolve(const Expr &e); +// Type VisitFunction(const Function & f, bool generalize); +// Operator CheckOp(Operator op); +// Defn CheckDefn(Defn def); +// private: +// Type VisitExpr_(const LocalIdNode* op) override; +// Type VisitExpr_(const GlobalIdNode* op) override; +// Type VisitExpr_(const OperatorIdNode* op) override; +// Type VisitExpr_(const FloatLitNode* op) override; +// Type VisitExpr_(const BoolLitNode* op) override; +// Type VisitExpr_(const IntLitNode* op) override; +// Type VisitExpr_(const TensorLitNode* op) override; +// Type VisitExpr_(const TupleNode* op) override; +// Type VisitExpr_(const CastNode* op) override; +// Type VisitExpr_(const ParamNode* op) override; +// Type VisitExpr_(const FunctionNode* op) override; +// Type VisitExpr_(const CallNode* op) override; +// Type VisitExpr_(const DebugNode* op) override; +// Type VisitExpr_(const LetNode* op) override; +// Type VisitExpr_(const ReverseNode* op) override; +// Type VisitExpr_(const GradientNode* op) override; +// Type VisitExpr_(const ProjectionNode* op) override; +// Type VisitExpr_(const IfNode* op) override; +// Type VisitExpr_(const RefNode* op) override; +// Type VisitExpr_(const ReadRefNode* op) override; +// Type VisitExpr_(const WriteRefNode* op) override; +// Type simple_eval_shape(const Type &shape); +// }; +// struct TypecheckerError : public dmlc::Error { +// explicit TypecheckerError(const std::string &msg) : Error(msg) {} +// }; + +// Typechecker::Typechecker() { +// this->env = EnvironmentNode::make({}); +// this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); +// } + +// Typechecker::Typechecker(Environment env) : env(env) { +// this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); +// } + +// Type Typechecker::Check(const Expr &expr) { +// RELAY_LOG(INFO) << "Typechecker::Check expr=" << expr << std::endl; +// Type ret = this->VisitExpr(expr); +// RELAY_LOG(INFO) << "Typechecker::Check type=" << expr << std::endl; +// ret = this->unifier->subst(ret); +// RELAY_LOG(INFO) << "Typechecker::Check type_after_subst=" << ret << std::endl; +// expr->checked_type_ = ret; +// return ret; +// } + +// Type Typechecker::VisitExpr_(const LocalIdNode *op) { +// LocalId id = GetRef(op); +// return this->local_stack.lookup(id); +// } + +// Type Typechecker::VisitExpr_(const GlobalIdNode *op) { +// GlobalId id = GetRef(op); +// Item item = this->env->lookup(id); + +// if (const OperatorNode *op = item.as()) { +// return op->type; +// } + +// if (const DefnNode *dn = item.as()) { +// Defn def = GetRef(dn); +// return def->type; +// } + +// this->fatal_error("Unhandled case in GlobalId", op->span); +// } + +// Type Typechecker::VisitExpr_(const OperatorIdNode *op) { +// OperatorId id = GetRef(op); +// Item item = this->env->lookup(id); + +// if (const OperatorNode *pn = item.as()) { +// Operator prim = GetRef(pn); +// return prim->type; +// } else { +// this->fatal_error("internal error in InstrinsicId case", op->span); +// } +// } + +// Type Typechecker::VisitExpr_(const FloatLitNode *op) { return FloatType(); } + +// Type Typechecker::VisitExpr_(const BoolLitNode *op) { return BoolType(); } + +// Type Typechecker::VisitExpr_(const IntLitNode *op) { return IntType(); } + +// Type Typechecker::VisitExpr_(const TensorLitNode *op) { +// TensorLit lit = GetRef(op); + +// if (lit->data.size() == 0) { +// this->fatal_error("Tensor literal must have at least one member", op->span); +// } + +// // unify types of all members to figure out shape, also ensure that +// // each member has compatible shape +// Type unified = this->Check(lit->data[0]); +// for (auto elt = lit->data.begin(); elt != lit->data.end(); elt++) { +// // evaluate all shape ASTs so they can be in standard form +// // TODO(sslyu): eventually we'd want this to be symbolic evaluation +// auto elt_el = *elt; +// Type elt_type = simple_eval_shape(this->Check(*elt)); +// if (!elt_type.as()) { +// this->fatal_error("All members in tensor literal must be tensors", +// elt_el->span); +// } +// unified = this->unify(unified, elt_type, lit->span); +// } + +// // types must unify into a tensor +// const TensorTypeNode *ttn = unified.as(); +// // shouldn't be possible due to check inside the loop +// if (!ttn) { +// this->fatal_error("Tensor literal contains non-tensor member", op->span); +// } + +// TensorType unified_tt = GetRef(ttn); + +// // new shape: add length of this tensor to front of existing shape +// // i.e., sequence and simplify +// // TODO(sslyu): should be symbolic evaluation eventually? +// Type new_shape = ShapeSeqNode::make( +// {ShapeSingletonNode::make(lit->data.size()), unified_tt->shape}); +// return TensorTypeNode::make(unified_tt->dtype, simple_eval_shape(new_shape)); +// } + +// Type Typechecker::VisitExpr_(const TupleNode *op) { +// Tuple pl = GetRef(op); + +// std::vector field_types; +// for (auto field = pl->fields.begin(); field != pl->fields.end(); field++) { +// field_types.push_back(this->Check(*field)); +// } + +// return TupleTypeNode::make(field_types); +// } + +// Type Typechecker::VisitExpr_(const CastNode *op) { +// // will take the cast at its word +// Cast cast = GetRef(op); +// return cast->target; +// } + +// Type Typechecker::VisitExpr_(const ParamNode *op) { +// Param param = GetRef(op); +// return resolve(param->type); +// } + +// // We should probably generalize the subst code. +// struct GeneralizeTypeType : TypeFVisitor { +// Map vars_to_id; +// const TypeUnifier &unifier; + +// GeneralizeTypeType(Map vars_to_id, +// const TypeUnifier &unifier) +// : vars_to_id(vars_to_id), unifier(unifier) {} + +// Type VisitType_(const TypeVarNode *op) override { +// auto repr = unifier->subst(GetRef(op)); +// if (auto tvn = repr.as()) { +// auto ty_var = GetRef(tvn); +// if (vars_to_id.find(ty_var) != vars_to_id.end()) { +// return vars_to_id[ty_var]; +// } else { +// return ty_var; +// } +// } else { +// return this->VisitType(repr); +// } +// } +// }; + +// struct GeneralizeTypeExpr : ExprFVisitor<> { +// Map vars_to_id; +// const TypeUnifier &unifier; + +// GeneralizeTypeExpr(const TypeUnifier &unifier, +// Map vars_to_id) +// : vars_to_id(vars_to_id), unifier(unifier) {} + +// Type VisitType(const Type &t) { +// return GeneralizeTypeType(vars_to_id, unifier).VisitType(t); +// } +// }; + +// Type Typechecker::VisitFunction(const Function &f, bool generalize) { +// // enter params into context +// auto fn_type = this->with_frame([&]() { +// std::vector arg_types; +// for (auto arg : f->params) { +// this->Check(arg); +// Type arg_type; +// // if arg type can be simply evaluated, try it +// // should be replaced with symbolic evaluation once it exists, +// // you will not have attr information at this point +// try { +// arg_type = simple_eval_shape(arg->type); +// } catch (const dmlc::Error &e) { +// this->report_error(e.what(), arg->span); +// arg_type = arg->type; +// } +// arg_types.push_back(arg_type); +// this->local_stack.insert(arg->id, arg_type); +// } + +// // typecheck body and ensure that it matches stated return type +// // TODO(sslyu): should the unified return type override the annotated one? +// Type checked_return = this->Check(f->body); +// Type ret_type = resolve(f->ret_type); +// Type unified = this->unify(simple_eval_shape(ret_type), +// simple_eval_shape(checked_return), f->span); +// return TypeArrowNode::make(arg_types, unified); +// }); +// if (generalize) { +// auto free_vars = free_type_vars(resolve(fn_type)); +// std::set dedup_free_vars; + +// for (auto free_var : free_vars) { +// auto repr = this->unifier->subst(free_var); +// if (auto new_free_var_node = repr.as()) { +// dedup_free_vars.insert(GetRef(new_free_var_node)); +// } else { +// // debug(repr); +// throw dmlc::Error( +// "internal error: this list should only contain type var nodes"); +// } +// } + +// Map vars_to_id; + +// GenFresh gf; +// for (auto free_var : dedup_free_vars) { +// vars_to_id.Set(free_var, gf.freshTV(free_var->kind)); +// } + +// fn_type = GeneralizeTypeType(vars_to_id, unifier).VisitType(fn_type); +// for (std::pair pair : vars_to_id) { +// // NB: In generalization we want to find type variables with +// // *no constraints* on them, and convert them to universally quantified +// // variables. +// // +// // i.e the program can be abstracted over the details of *that* type. + +// // For example a program that works irrespective of shape or datatype. + +// // In order to do this we find the set of free type variables in the +// // term, and then unify them with the fresh type ids we generate. +// // +// // Remember importantly these type variables still may appear in many +// // places in the program including both types and expressions. + +// // Our method for resolving these is to unify them with the variables +// // as we build the new quanitifer, changing from a program with "holes" +// // to one that is properly abstracted over. + +// // Finally later on we can iterate over the whole term and change from +// // type variables to these type ids. +// this->unify(pair.first, pair.second, pair.second->span); +// fn_type = TypeQuantifierNode::make(pair.second, fn_type); +// } +// } else { +// for (auto i = f->ty_params.size(); i > 0; i--) { +// auto ty_param = f->ty_params[i - 1]; +// auto ty_param_node = ty_param.as(); +// if (!ty_param_node) { +// throw dmlc::Error("internal error should be TypeParam"); +// } +// auto fresh_tid = +// TypeParamNode::make(ty_param_node->name, ty_param_node->kind); +// fn_type = +// type_subst(fn_type, GetRef(ty_param_node), fresh_tid); +// fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); +// } +// } + +// return fn_type; +// } + +// Type Typechecker::VisitExpr_(const FunctionNode *op) { +// return this->VisitFunction(GetRef(op), false); +// } + +// Type Typechecker::instantiate(Type t, tvm::Array &ty_args) { +// const TypeQuantifierNode *ty_quant; +// while ((ty_quant = t.as())) { +// TypeParam id = ty_quant->id; +// TypeVar fresh = TypeVarNode::make(id->kind); +// this->unifier->insert(fresh); +// ty_args.push_back(fresh); +// t = type_subst(ty_quant->boundType, id, fresh); +// } + +// if (!check_kind(t)) { +// this->fatal_error("Kind rules broken when instantiating type variables", +// t->span); +// } + +// return t; +// } + +// Type Typechecker::VisitExpr_(const CallNode *op) { +// Call c = GetRef(op); +// Type fn_ty = this->Check(c->fn); + +// RELAY_LOG(INFO) << "Typechecker::VisitExpr_ op=" << c << std::endl +// << "fn_ty=" << fn_ty << std::endl; + +// // for each type id, insert a type variable and unify with the argument types +// // in order +// // to obtain the concrete instantiation +// tvm::Array ty_args; +// if (const TypeQuantifierNode *ty_quant = fn_ty.as()) { +// fn_ty = instantiate(GetRef(ty_quant), ty_args); +// } + +// if (!fn_ty.as()) { +// this->fatal_error("only expressions with function types can be called", +// c->fn->span); +// } + +// // evaluate all shapes up front (require that types be fully concrete) +// Type evaluated = evaluate_concrete_shape(fn_ty, op->attrs); +// std::vector arg_types; + +// TypeArrow arrow = GetRef(evaluated.as()); + +// // TODO(sslyu): figure out how to handle type ids +// // fn_ty = instantiate(fn_ty, ty_args); +// for (auto arg : c->args) { +// auto ty = this->Check(arg); +// arg_types.push_back(ty); +// } + +// auto type_arity = arrow->arg_types.size(); +// auto number_of_args = arg_types.size(); +// if (type_arity != number_of_args) { +// if (type_arity < number_of_args) { +// this->fatal_error("the function is provided too many arguments", c->span); +// } else { +// this->fatal_error("the function is provided too few arguments", c->span); +// } +// } + +// for (size_t i = 0; i < arrow->arg_types.size(); i++) { +// this->unify(arrow->arg_types[i], arg_types[i], c->args[i]->span); +// } + +// // After we unify the arguments we should know more about the type +// // arguments, let's run a quick pass over them to find new representatives. +// for (size_t i = 0; i < ty_args.size(); i++) { +// ty_args.Set(i, this->unifier->subst(ty_args[i])); +// } + +// // Write the type arguments into the call node, recording what inference +// // solves. This solution might need some work. +// c->ty_args = ty_args; + +// return arrow->ret_type; +// } + +// Type Typechecker::VisitExpr_(const DebugNode *op) { +// return this->Check(op->node); +// } + +// Type Typechecker::VisitExpr_(const LetNode *op) { +// Let let = GetRef(op); + +// Type checked_ty; +// Type annotated_ty = resolve(let->type); + +// // if we are let-defining a function, treat it as a let-rec and insert +// // the id with the annotated type in case there is recursion; +// // no such recursion permitted with anything that's not a function! +// if (let->value.as()) { +// with_frame([&]() { +// local_stack.insert(let->id, annotated_ty); +// checked_ty = Check(let->value); +// }); +// } else { +// checked_ty = Check(let->value); +// } + +// // ensure annotated type and checked type are compatible +// // TODO(sslyu): should the annotated type override the unified one? +// Type unified_ty = +// this->unify(checked_ty, simple_eval_shape(annotated_ty), let->span); + +// return with_frame([&]() { +// local_stack.insert(let->id, unified_ty); +// return Check(let->body); +// }); +// } + +// Type Typechecker::VisitExpr_(const ReverseNode *op) { +// // apply reverse mode to node and typecheck that instead +// std::shared_ptr gf = std::make_shared(); +// return this->Check(ReverseExpr(env, op->node, gf)); +// } + +// Type Typechecker::VisitExpr_(const GradientNode *op) { +// auto node = op->node; +// this->Check(node); +// auto gf = std::make_shared(); +// return FOWithGradientType(node->checked_type()); +// } + +// Type Typechecker::VisitExpr_(const ProjectionNode *op) { +// Projection proj = GetRef(op); + +// Type tup_type = this->Check(proj->tuple); + +// const TupleTypeNode *ptn = tup_type.as(); +// if (!ptn) { +// this->fatal_error("Cannot project into non-product type", op->span); +// } + +// TupleType pt = GetRef(ptn); +// size_t field = (size_t)proj->field; +// if (field >= pt->fields.size()) { +// this->fatal_error("Projecting past bounds of product", op->span); +// } + +// return pt->fields[field]; +// } + +// Type Typechecker::VisitExpr_(const IfNode *op) { +// If ifn = GetRef(op); + +// // Ensure the type of the guard is of Tensor[Bool, ()], +// // that is a rank-0 boolean tensor. +// Type guardType = this->Check(ifn->guard); +// bool is_bool = false; +// bool zero_rank = false; +// if (const TensorTypeNode *ttn = guardType.as()) { +// TensorType tt = GetRef(ttn); + +// if (const BaseTypeNode *btn = tt->dtype.as()) { +// is_bool = btn->type.is_bool(); +// } + +// Type shape = simple_eval_shape(tt->shape); + +// if (const ShapeSeqNode *sn = shape.as()) { +// zero_rank = (sn->shapes.size() == 0); +// } +// } + +// if (!(is_bool && zero_rank)) { +// this->fatal_error("IfNode guard must be a rank 0 bool tensor", +// ifn->guard->span); +// } + +// // unify types of different branches +// Type left = this->Check(ifn->true_b); +// Type right = this->Check(ifn->false_b); +// return this->unify(left, right, ifn->span); +// } + +// Type Typechecker::VisitExpr_(const RefNode *op) { +// Ref r = GetRef(op); +// Type inner = this->Check(r->expr); +// return RefTypeNode::make(inner); +// } + +// Type Typechecker::VisitExpr_(const ReadRefNode *op) { +// ReadRef vr = GetRef(op); +// Type ref_type = this->Check(vr->ref); + +// // reject if not a ref type +// const RefTypeNode *rtn = ref_type.as(); +// if (!rtn) { +// this->fatal_error( +// "the de-reference operation can only be used with references", +// op->span); +// } + +// RefType rt = GetRef(rtn); +// return rt->data_type; +// } + +// Type Typechecker::VisitExpr_(const WriteRefNode *op) { +// WriteRef sr = GetRef(op); +// Type ref_type = this->Check(sr->ref); + +// const RefTypeNode *rtn = ref_type.as(); +// if (!rtn) { +// this->fatal_error("Cannot mutate non-ref", op->span); +// } +// RefType rt = GetRef(rtn); + +// // ensure ref type's inner type and expr's type are compatible; return unit +// Type expr_type = this->Check(sr->val); +// this->unify(rt->data_type, expr_type, sr->span); +// return UnitType(); +// } + +// Type Typechecker::resolve(const Type &t) { +// return ::tvm::relay::resolve(this->unifier, t); +// } + +// Expr Typechecker::resolve(const Expr &e) { +// return ::tvm::relay::resolve(this->unifier, e); +// } + +// Type Typechecker::simple_eval_shape(const Type &shape) { +// // TODO(sslyu): Do we want to propagate attributes? +// Attributes empty = AttributesNode::make({}); +// return evaluate_concrete_shape(shape, empty); +// } + +// Operator Typechecker::CheckOp(Operator op) { +// if (!check_kind(op->type)) { +// report_error("the type of the operator is ill formed", op->type->span); +// } + +// // Fix me +// return op; +// } + +// Defn Typechecker::CheckDefn(Defn defn) { +// // This is to handle recursion, but we need to speculatively +// // put it in env, then remove it. +// env->items.insert({defn->id, defn}); + +// Type expected_ty = this->resolve(defn->type); + +// Expr body = defn->body; + +// auto checked_ty = Check(body); + +// try { +// Type uret_type = unify(expected_ty, checked_ty, defn->body->span); +// CHECK(is_fully_resolved(uret_type)); +// // Now let's clean up our work from earlier. +// env->items.erase(defn->id); +// return DefnNode::make(defn->id, uret_type, this->resolve(defn->body)); +// } catch (const UnificationError& err) { +// std::string msg = std::string("mismatch between `") + +// PrintType(env, expected_ty, WrapWidth(40)) + "` and `" + +// PrintType(env, checked_ty, WrapWidth(40)) + "`"; +// fatal_error(msg, defn->span); +// } +// } + +// Type check(const Environment &env, const Expr &e) { +// Typechecker tc(env); +// return tc.Check(e); +// } + +// Item check(const Environment &env, const Item &i) { +// Typechecker tc(env); + +// try { +// if (const DefnNode *defn = i.as()) { +// return tc.CheckDefn(GetRef(defn)); +// } else if (const OperatorNode *op_node = i.as()) { +// return tc.CheckOp(GetRef(op_node)); +// } else { +// throw dmlc::Error("internal error: unknown Item type"); +// } +// } catch (const FatalTypeError &err) { +// env->display_errors(); +// throw dmlc::Error( +// "We encountered a fatal error while type checking your program, please " +// "read above for more details."); +// } +// } + +// inline void Typechecker::report_error(const std::string &msg, Span sp) { +// this->env->report_error(msg, sp); +// } + +// void Typechecker::fatal_error(const std::string &msg, Span sp) { +// this->env->report_error(msg, sp); +// throw FatalTypeError( +// "internal error: this exception should" +// "be handled and errors reported with Environment::display_errors\n" + +// msg); +// } + +// Type Typechecker::unify(const Type &t1, const Type &t2, Span sp) { +// try { +// return this->unifier->unify(t1, t2); +// } catch (const dmlc::Error &e) { +// std::stringstream ss; +// ss << "Error unifying `"; +// ss << PrintType(env, t1, WrapWidth(40)); +// ss << "` and `"; +// ss << PrintType(env, t2, WrapWidth(40)); +// ss << "`: " << e.what(); +// this->fatal_error(ss.str(), sp); +// } +// } + +// // template + +// // Add safe dynamic Array downcast. +// // Add static upcast? + +// // Add to type utils. +// Array type_parameters(const Type &t) { +// Array params; +// auto type = t; +// const TypeQuantifierNode *ty_quant; +// while ((ty_quant = type.as())) { +// params.push_back(ty_quant->id); +// type = ty_quant->boundType; +// } + +// return params; +// } + +// template +// Array ArrayMap(const Array &data, F f) { +// // probably a way to use std::transform. +// Array output; +// for (const I &el : data) { +// output.push_back(f(el)); +// } +// return output; +// } + +// // There are some important questions around generalization +// // that we need to answer. +// Expr generalize(const Environment &env, const Expr &e) { +// if (auto fn_node = e.as()) { +// Typechecker tc(env); +// auto ty = tc.VisitFunction(GetRef(fn_node), true); +// auto ty_params = type_parameters(ty); +// auto params = ArrayMap(fn_node->params, [&](const Param &p) { +// return ParamNode::make(p->id, tc.resolve(p->type)); +// }); +// auto body = tc.resolve(fn_node->body); +// auto ret_type = tc.resolve(fn_node->ret_type); +// auto fn = FunctionNode::make(ty_params, params, ret_type, body); +// // we should check in empty context to ensure typing is preserved. +// // check(env, fn); +// return fn; +// } else { +// throw dmlc::Error("can only apply generalize to a function."); +// } +// } + +// TVM_REGISTER_API("relay._tyck.check_expr") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// Expr e = args[1]; +// *ret = check(env, e); +// }); + +// TVM_REGISTER_API("relay._tyck.check_item") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// Item i = args[1]; +// *ret = check(env, i); +// }); + +// TVM_REGISTER_API("relay._tyck.get_checked_type") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Expr e = args[0]; +// *ret = e->checked_type(); +// }); + +// TVM_REGISTER_API("relay._tyck.generalize") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// *ret = generalize(args[0], args[1]); +// }); + +IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { + std::shared_ptr n = std::make_shared(); + n->kind = std::move(kind); + return IncompleteType(n); +} + +TVM_REGISTER_API("relay._make.IncompleteType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + int kind = args[0]; + *ret = IncompleteTypeNode::make(static_cast(kind)); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const IncompleteTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "IncompleteTypeNode(" << node->kind << ", " << &node << ")"; + }); + +} // namespace relay +} // namespace tvm From ee218d04d56ed1550b536e13a2bde35d200222a9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 14:06:44 -0700 Subject: [PATCH 19/88] Add type_visitor.h --- src/relay/compiler/type_visitor.h | 107 ++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 src/relay/compiler/type_visitor.h diff --git a/src/relay/compiler/type_visitor.h b/src/relay/compiler/type_visitor.h new file mode 100644 index 000000000000..5ae100a8de6d --- /dev/null +++ b/src/relay/compiler/type_visitor.h @@ -0,0 +1,107 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_visitor.h + * \brief A wrapper around TypeFunctor for common use cases. + */ +#ifndef TVM_RELAY_TYPE_VISITOR_H_ +#define TVM_RELAY_TYPE_VISITOR_H_ + +#include +#include "./type_functor.h" + +namespace tvm { +namespace relay { + +/*! \brief A type visitor for vistiors which make use of internal + * mutable state. + * + * We recursively visit each type contained inside the visitor. + */ +template +struct TypeVisitor : TypeFunctor { + // void VisitType_(const TypeVarNode* op, Args... args) override {} + void VisitType_(const TypeParamNode* op, Args... args) override {} + + void VisitType_(const FuncTypeNode* op, Args... args) override { + // this->VisitType(op->id, args...); + // this->VisitType(op->boundType, args...); + // for (auto arg_type : op->arg_types) { + // this->VisitType(arg_type, args...); + // } + // this->VisitType(op->ret_type, args...); + } + + void VisitType_(const TensorTypeNode* op, Args... args) override { + // this->VisitType(op->dtype, args...); + // this->VisitType(op->shape, args...); + } + +// void VisitType_(const TupleTypeNode* op, Args... args) override { +// for (const Type& t : op->fields) { +// this->VisitType(t, args...); +// } +// } + +// void VisitType_(const TypeCallNode* op, Args... args) override { +// for (const Type& t : op->args) { +// this->VisitType(t, args...); +// } +// } + + void VisitType_(const TypeFunctionNode* op, Args... args) override {} + void VisitType_(const IncompleteTypeNode* op, Args... args) override {} +}; + +// A functional visitor for rebuilding an AST in place. +struct TypeFVisitor : TypeFunctor { + Type VisitType_(const TensorTypeNode* op) override { + // TODO (@jroesch): maybe we should recursively visit + return TensorTypeNode::make(op->shape, op->dtype); + } + + Type VisitType_(const TypeParamNode* op) override { + return GetRef(op); + } + +// Type VisitType_(const TypeArrowNode* op) override { +// std::vector args; +// for (auto arg_type : op->arg_types) { +// args.push_back(VisitType(arg_type)); +// } +// return TypeArrowNode::make(tvm::Array(args), VisitType(op->ret_type)); +// } + +// Type VisitType_(const TypeQuantifierNode* op) override { +// auto new_id = this->VisitType(op->id); +// if (const TypeParamNode* tin = new_id.as()) { +// return TypeQuantifierNode::make(GetRef(tin), +// this->VisitType(op->boundType)); +// } else { +// throw dmlc::Error("Cannot quantify something that is not a type ID"); +// } +// } + +// Type VisitType_(const TupleTypeNode* op) override { +// std::vector new_fields; +// for (const Type& t : op->fields) { +// new_fields.push_back(this->VisitType(t)); +// } +// return TupleTypeNode::make(new_fields); +// } + +// Type VisitType_(const TypeCallNode* op) override { +// auto func = this->VisitType(op->func); +// std::vector new_args; +// for (const Type& t : op->args) { +// new_args.push_back(this->VisitType(t)); +// } +// return TypeCallNode::make(func, new_args); +// } + Type VisitType_(const IncompleteTypeNode* op) override { + return GetRef(op); + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_TYPE_VISITOR_H_ From c1e20475fef43d165030bbaeda519a8b353bd845 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 14:07:13 -0700 Subject: [PATCH 20/88] Add expr_visitor.h --- include/tvm/relay/expr_visitor.h | 166 +++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 include/tvm/relay/expr_visitor.h diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h new file mode 100644 index 000000000000..d7ac1465f70a --- /dev/null +++ b/include/tvm/relay/expr_visitor.h @@ -0,0 +1,166 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file expr_visitor.h + * \brief A simple visitor wrapper around ExprFunctor designed for visitors which + * maintain mutable state. + */ +#ifndef TVM_RELAY_EXPR_VISITOR_H_ +#define TVM_RELAY_EXPR_VISITOR_H_ + +#include "expr_functor.h" + +namespace tvm { +namespace relay { + +template +class ExprVisitor : public ExprFunctor { + public: + void VisitExpr_(const LocalVarNode* op, Args... args) override { return; } + + void VisitExpr_(const GlobalVarNode* op, Args... args) override { return; } + + void VisitExpr_(const ConstantNode* op, Args... args) override { return; } + + void VisitExpr_(const TupleNode* op, Args... args) override { + for (auto field : op->fields) { + this->VisitExpr(field, args...); + } + } + + void VisitExpr_(const ParamNode* op, Args... args) override { + this->VisitExpr(op->var, args...); + } + + void VisitExpr_(const FunctionNode* op, Args... args) override { + for (auto param : op->params) { + this->VisitExpr(param, args...); + } + + this->VisitExpr(op->body, args...); + } + + void VisitExpr_(const CallNode* op, Args... args) override { + this->VisitExpr(op->op, args...); + for (auto arg : op->args) { + this->VisitExpr(arg, args...); + } + } + + void VisitExpr_(const LetNode* op, Args... args) override { + this->VisitExpr(op->var, args...); + this->VisitExpr(op->value, args...); + this->VisitExpr(op->body, args...); + } + + void VisitExpr_(const IfNode* op, Args... args) override { + this->VisitExpr(op->cond, args...); + this->VisitExpr(op->true_value, args...); + this->VisitExpr(op->false_value, args...); + } + + void VisitExpr_(const OperatorNode* op, Args... args) override { return; } +}; + +template +class ExprFVisitor : public ExprFunctor { + public: + Expr VisitExpr_(const LocalVarNode* op, Args... args) override { + return GetRef(op); + } + + Expr VisitExpr_(const GlobalVarNode* op, Args... args) override { + return GetRef(op); + } + + Expr VisitExpr_(const OperatorNode* op, Args... args) override { + return GetRef(op); + } + + Expr VisitExpr_(const TupleNode* op, Args... args) override { + tvm::Array fields; + for (auto field : op->fields) { + fields.push_back(this->VisitExpr(field, args...)); + } + + return TupleNode::make(fields); + } + + Expr VisitExpr_(const ParamNode* op, Args... args) override { + Expr var_expr = this->VisitExpr(op->var, args...); + if (const LocalVarNode* var_node = var_expr.as()) { + auto var = GetRef(var_node); + auto type = this->VisitType(op->type, args...); + return ParamNode::make(var, type); + } else { + throw dmlc::Error("the default param visitor has bug"); + } + } + + Expr VisitExpr_(const FunctionNode* op, Args... args) override { + tvm::Array ty_params; + for (auto ty : op->type_params) { + ty_params.push_back(this->VisitType(ty, args...)); + } + + tvm::Array params; + for (auto param : op->params) { + Expr param_expr = this->VisitExpr(param, args...); + if (const ParamNode* param_node = param_expr.as()) { + auto param = GetRef(param_node); + params.push_back(param); + } else { + throw dmlc::Error("the default func visitor has bug"); + } + } + + auto ret_type = this->VisitType(op->ret_type, args...); + auto body = this->VisitExpr(op->body, args...); + return FunctionNode::make(ty_params, params, ret_type, body); + } + + Expr VisitExpr_(const CallNode* call_node, Args... args) override { + auto fn = this->VisitExpr(call_node->op, args...); + + tvm::Array ty_args; + for (auto ty_arg : call_node->type_args) { + auto new_ty_arg = this->VisitType(ty_arg, args...); + ty_args.push_back(new_ty_arg); + } + + tvm::Array call_args; + for (auto arg : call_node->args) { + call_args.push_back(this->VisitExpr(arg, args...)); + } + + auto call = CallNode::make(fn, call_args, call_node->attrs); + call->ty_args = ty_args; + + return call; + } + + Expr VisitExpr_(const LetNode* op, Args... args) override { + Expr var_expr = this->VisitExpr(op->var, args...); + if (const LocalVarNode* var_node = var_expr.as()) { + auto var = GetRef(var_node); + auto type = this->VisitType(op->value_type, args...); + auto value = this->VisitExpr(op->value, args...); + auto body = this->VisitExpr(op->body, args...); + return LetNode::make(var, type, value, body); + } else { + throw dmlc::Error("the default let visitor has error"); + } + } + + Expr VisitExpr_(const IfNode* op, Args... args) override { + auto guard = this->VisitExpr(op->cond, args...); + auto true_b = this->VisitExpr(op->true_value, args...); + auto false_b = this->VisitExpr(op->false_value, args...); + return IfNode::make(guard, true_b, false_b); + } + + virtual Type VisitType(const Type& t, Args... args) { return t; } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_EXPR_VISITOR_H_ From 36d92c2c15bb18d1c8748caafa44b6ecf8715467 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 14:07:25 -0700 Subject: [PATCH 21/88] Start reparing unifier and tests --- python/tvm/relay/ir_builder.py | 11 ++++++++ tests/python/relay/test_unifier.py | 43 ++++++++++++------------------ 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 497479140ec9..3c842e480c70 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -100,5 +100,16 @@ def get(self): return _mk_let(bindings, self.ret_value) +# def int_type(): +# return TensorType(IntType(32), ShapeSeq([])) + +# def float_type(): +# return TensorType(FloatType(32), ShapeSeq([])) + +# def bool_type(): +# return TensorType(BoolType(), ShapeSeq([])) + +# def make_shape(dims): +# return ShapeSeq([ShapeSingleton(dim) for dim in dims]) diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py index 875502808563..b2ed075ca3de 100644 --- a/tests/python/relay/test_unifier.py +++ b/tests/python/relay/test_unifier.py @@ -2,30 +2,14 @@ Test the type unifier, which solves systems of equations between incomplete types. """ -import relay.ir -import relay.unifier - - -def unify_types(t1, t2): - unifier = TypeUnifier() - return unifier.unify(t1, t2) - -def int_type(): - return TensorType(IntType(32), ShapeSeq([])) - -def float_type(): - return TensorType(FloatType(32), ShapeSeq([])) - -def bool_type(): - return TensorType(BoolType(), ShapeSeq([])) - -def make_shape(dims): - return ShapeSeq([ShapeSingleton(dim) for dim in dims]) +import tvm.relay.ir +from tvm.relay.unifier import UnionFind, TypeUnifier +import tvm.relay.make as mk def test_insert_and_find(): uf = UnionFind() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) + v1 = mk.TypeVar(ir.Kind.Type) + v2 = mk.TypeVar(ir.Kind.Type) uf.insert(v1) uf.insert(v2) assert uf.find(v1) == v1 @@ -33,8 +17,8 @@ def test_insert_and_find(): def test_insert_error(): uf = UnionFind() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) + v1 = mk.TypeVar(ir.Kind.Type) + v2 = mk.TypeVar(ir.Kind.Type) uf.insert(v1) try: uf.find(v2) @@ -44,9 +28,9 @@ def test_insert_error(): def test_unify(): uf = UnionFind() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) - v3 = TypeVar(ir.Kind.Type) + v1 = mk.TypeVar(ir.Kind.Type) + v2 = mk.TypeVar(ir.Kind.Type) + v3 = mk.TypeVar(ir.Kind.Type) uf.insert(v1) uf.insert(v2) uf.insert(v3) @@ -97,6 +81,13 @@ def test_unify_multiple_levels(): for i in range(3): assert uf.find(v[i + 6]) == new_rep2 +# We have checked that the basic machinery in the UnionFind works +# and now we will test the type unifier which will fill in holes +# between type equalities by the process of unification. +def unify_types(t1, t2): + unifier = TypeUnifier() + return unifier.unify(t1, t2) + # TODO(sslyu, weberlo, joshpoll): put in isinstance asserts once those work def test_unify_int(): intty = IntType(1) From 221d15773fad6e4a192fdc9a330c06dcfc9c4781 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 14:21:12 -0700 Subject: [PATCH 22/88] Fix test_unifier.py, now runs but all tests fail --- python/tvm/relay/make.py | 5 + tests/python/relay/test_alpha_eq.py | 1148 +++++++++++++-------------- tests/python/relay/test_unifier.py | 130 +-- 3 files changed, 643 insertions(+), 640 deletions(-) diff --git a/python/tvm/relay/make.py b/python/tvm/relay/make.py index 14d9ac040dc9..a2b87f2700af 100644 --- a/python/tvm/relay/make.py +++ b/python/tvm/relay/make.py @@ -18,3 +18,8 @@ Call = _make.Call Let = _make.Let If = _make.If +IncompleteType = _make.IncompleteType + +# Unifier +UnionFind = _make.UnionFind +TypeUnifier = _make.TypeUnifier diff --git a/tests/python/relay/test_alpha_eq.py b/tests/python/relay/test_alpha_eq.py index f1dc81c3c483..e4fbbcca93ce 100644 --- a/tests/python/relay/test_alpha_eq.py +++ b/tests/python/relay/test_alpha_eq.py @@ -1,576 +1,574 @@ """Test alpha-equivalence of expressions and types.""" -# pylint: disable=invalid-name, missing-docstring -# pylint: disable=wildcard-import, unused-wildcard-import -from relay.make import * -from relay.ir import alpha_eq, ShapeOp, Kind -from relay.typing import TYPE_DEFAULTS -from relay import ir - -INT_TYPE_WIDTH = TYPE_DEFAULTS["INT_WIDTH"] -INT_TYPE_LANES = TYPE_DEFAULTS["INT_LANES"] - -def int_type(width=32) -> ir.Type: - return TensorType(IntType(width), ShapeSeq([])) - -def float_type(width=32) -> ir.Type: - return TensorType(FloatType(width), ShapeSeq([])) - -def bool_type() -> ir.Type: - return TensorType(BoolType(), ShapeSeq([])) - -def nest_quantifiers(ids, body) -> ir.Type: - ret = body - for tid in reversed(ids): - ret = TypeQuantifier(tid, ret) - return ret - -def test_local_id_not_eq() -> None: - assert not alpha_eq(LocalId("x"), LocalId("y")) - -def test_local_id_eq() -> None: - x = LocalId("x") - assert alpha_eq(x, x) - -def test_global_id_not_eq() -> None: - left = GlobalId("xyz") - right = GlobalId("xyz") - assert not alpha_eq(left, right) - -def test_global_id_eq() -> None: - ident = GlobalId("xyz") - assert alpha_eq(ident, ident) - -def test_operator_id_not_eq() -> None: - left = OperatorId("xyz") - right = OperatorId("xyz") - # equality on operator id is pointer equality - assert not alpha_eq(left, right) - -def test_operator_id_eq() -> None: - x = OperatorId("xyz") - assert alpha_eq(x, x) - -def test_float_literal_eq() -> None: - x = FloatLit(1.0) - y = FloatLit(1.0) - assert alpha_eq(x, y) - -def test_float_literal_not_eq() -> None: - x = FloatLit(1.0) - y = FloatLit(2.0) - assert not alpha_eq(x, y) - -def test_int_literal_eq() -> None: - x = IntLit(1) - y = IntLit(1) - assert alpha_eq(x, y) - -def test_int_literal_not_eq() -> None: - x = IntLit(1) - y = IntLit(2) - assert not alpha_eq(x, y) - -def test_bool_literal_eq() -> None: - x = BoolLit(True) - y = BoolLit(True) - assert alpha_eq(x, y) - -def test_bool_literal_not_eq() -> None: - x = BoolLit(True) - y = BoolLit(False) - assert not alpha_eq(x, y) - -def test_tensor_literal_eq() -> None: - x = TensorLit([IntLit(1), IntLit(2)]) - y = TensorLit([IntLit(1), IntLit(2)]) - assert alpha_eq(x, y) - -def test_tensor_literal_not_eq() -> None: - x = TensorLit([IntLit(1), IntLit(2)]) - y = TensorLit([IntLit(1), IntLit(3)]) - z = TensorLit([IntLit(1)]) - assert not alpha_eq(x, y) - assert not alpha_eq(x, z) - -def test_product_literal_eq() -> None: - x = Tuple([IntLit(1), IntLit(2)]) - y = Tuple([IntLit(1), IntLit(2)]) - assert alpha_eq(x, y) - -def test_product_literal_not_eq() -> None: - x = Tuple([IntLit(1), IntLit(2)]) - y = Tuple([IntLit(2), IntLit(2)]) - z = Tuple([IntLit(1), IntLit(2), IntLit(3)]) - assert not alpha_eq(x, y) - assert not alpha_eq(x, z) - -def test_projection_eq() -> None: - prod = Tuple([IntLit(3), FloatLit(3.5)]) - - assert alpha_eq(Projection(prod, 0), Projection(prod, 0)) - assert alpha_eq(Projection(prod, 1), Projection(prod, 1)) - -def test_projection_not_eq() -> None: - prod1 = Tuple([IntLit(3), IntLit(4)]) - prod2 = Tuple([IntLit(3)]) - prod3 = Tuple([IntLit(3), IntLit(4), FloatLit(3.5)]) - - assert not alpha_eq(Projection(prod1, 0), Projection(prod1, 1)) - assert not alpha_eq(Projection(prod1, 0), Projection(prod2, 0)) - assert not alpha_eq(Projection(prod1, 0), Projection(prod3, 0)) - assert not alpha_eq(Projection(prod1, 1), Projection(prod3, 1)) - -def test_cast_not_eq() -> None: - left = Cast(IntType(1), IntLit(2)) - right = Cast(IntType(1), IntLit(1)) - assert not alpha_eq(left, right) - - # same literal, different type - left = Cast(IntType(1), IntLit(2)) - right = Cast(IntType(2), IntLit(2)) - assert not alpha_eq(left, right) - -def test_cast_eq() -> None: - left = Cast(IntType(1), IntLit(2)) - right = Cast(IntType(1), IntLit(2)) - assert alpha_eq(left, right) - -def test_param_not_eq() -> None: - left = Param(LocalId("foo"), int_type()) - right = Param(LocalId("foo"), bool_type()) - assert not alpha_eq(left, right) - -def test_param_eq() -> None: - left = Param(LocalId("foo"), int_type()) - right = Param(LocalId("bar"), int_type()) - assert alpha_eq(left, right) - -def test_function_not_eq() -> None: - params1 = [Param(LocalId("x"), int_type())] - fn1 = Function([], params1, int_type(), LocalId("x")) - params2 = [Param(LocalId("y"), bool_type())] - fn2 = Function([], params2, int_type(), LocalId("y")) - assert not alpha_eq(fn1, fn2) - - params3 = [Param(LocalId("x"), int_type()), Param(LocalId("y"), int_type())] - fn3 = Function([], params3, int_type(), LocalId("z")) - assert not alpha_eq(fn1, fn3) - -def test_function_eq() -> None: - x = LocalId("x") - y = LocalId("y") - params1 = [Param(x, int_type())] - fn1 = Function([], params1, int_type(), x) - params2 = [Param(y, int_type())] - fn2 = Function([], params2, int_type(), y) - assert alpha_eq(fn1, fn2) - -def test_call_not_eq() -> None: - x = LocalId("x") - y = LocalId("y") - params1 = [Param(x, int_type())] - fn1 = Function([], params1, int_type(), x) - args1 = [IntLit(1)] - call1 = Call(fn1, args1) - - args2 = [IntLit(2)] - call2 = Call(fn1, args2) - assert not alpha_eq(call1, call2) - - params2 = [Param(y, int_type())] - fn2 = Function([], params2, float_type(), FloatLit(0.0)) - call3 = Call(fn2, args1) - assert not alpha_eq(call1, call3) - assert not alpha_eq(call2, call3) - -def test_call_eq() -> None: - x = LocalId("x") - y = LocalId("y") - params1 = [Param(x, int_type())] - fn1 = Function([], params1, int_type(), x) - args = [IntLit(1)] - call1 = Call(fn1, args) - - params2 = [Param(y, int_type())] - fn2 = Function([], params2, int_type(), y) - call2 = Call(fn2, args) - assert alpha_eq(call1, call2) - -def test_debug_not_eq() -> None: - left = Debug(IntLit(1)) - right = Debug(IntLit(2)) - assert not alpha_eq(left, right) - -def test_debug_eq() -> None: - left = Debug(IntLit(1)) - right = Debug(IntLit(1)) - assert alpha_eq(left, right) - -def test_let_not_eq() -> None: - x = LocalId("x") - y = LocalId("y") - let1 = Let(x, int_type(), IntLit(10), IntLit(11)) - let2 = Let(y, int_type(), IntLit(10), IntLit(12)) - assert not alpha_eq(let1, let2) - - let3 = Let(x, int_type(), IntLit(10), x) - let4 = Let(y, int_type(), IntLit(12), y) - assert not alpha_eq(let3, let4) - -def test_let_eq() -> None: - x = LocalId("x") - y = LocalId("y") - let1 = Let(x, int_type(), IntLit(10), x) - let2 = Let(y, int_type(), IntLit(10), y) - assert alpha_eq(let1, let2) - -def test_ref_eq() -> None: - r1 = Ref(IntLit(5)) - r2 = Ref(IntLit(5)) - assert alpha_eq(r1, r2) - -def test_ref_not_eq() -> None: - r1 = Ref(IntLit(5)) - r2 = Ref(FloatLit(3.5)) - r3 = Ref(r1) - assert not alpha_eq(r1, r2) - assert not alpha_eq(r1, r3) - assert not alpha_eq(r2, r3) - -def test_val_ref_eq() -> None: - vr1 = ReadRef(Ref(IntLit(35))) - vr2 = ReadRef(Ref(Tuple([IntLit(12), FloatLit(2.5)]))) - assert alpha_eq(vr1, vr1) - assert alpha_eq(vr2, vr2) - -def test_val_ref_not_eq() -> None: - vr1 = ReadRef(Ref(IntLit(5))) - vr2 = ReadRef(Ref(vr1)) - vr3 = ReadRef(Ref(FloatLit(5.0))) - assert not alpha_eq(vr1, vr2) - assert not alpha_eq(vr1, vr3) - assert not alpha_eq(vr2, vr3) - -def test_set_ref_eq() -> None: - sr1 = WriteRef(Ref(FloatLit(5.0)), FloatLit(6.0)) - sr2 = WriteRef(Ref(Tuple([IntLit(3), BoolLit(False)])), - Tuple([IntLit(5), BoolLit(True)])) - assert alpha_eq(sr1, sr1) - assert alpha_eq(sr2, sr2) - -def test_set_ref_not_eq() -> None: - r1 = Ref(FloatLit(5.0)) - r2 = Ref(IntLit(5)) - r3 = Ref(IntLit(6)) - - assert not alpha_eq(WriteRef(r1, FloatLit(6.0)), - WriteRef(r2, IntLit(6))) - assert not alpha_eq(WriteRef(r2, IntLit(6)), WriteRef(r2, IntLit(7))) - assert not alpha_eq(WriteRef(r2, IntLit(7)), WriteRef(r3, IntLit(7))) - -# Type alpha-equality tests - -def test_base_type_eq() -> None: - assert alpha_eq(IntType(32), IntType(32)) - assert alpha_eq(BoolType(), BoolType()) - assert alpha_eq(FloatType(32), FloatType(32)) - -def test_tensor_type_eq() -> None: - tt1 = TensorType( - IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) - tt2 = TensorType( - FloatType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) - assert alpha_eq(tt1, tt1) - assert alpha_eq(tt2, tt2) - -def test_tensor_type_not_eq() -> None: - tt1 = TensorType( - IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) - tt2 = TensorType( - FloatType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) - tt3 = TensorType( - IntType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) - assert not alpha_eq(tt1, tt2) - assert not alpha_eq(tt1, tt3) - -def test_ref_type_eq() -> None: - rt1 = RefType(int_type()) - rt2 = RefType(float_type()) - assert alpha_eq(rt1, rt1) - assert alpha_eq(rt2, rt2) - -def test_ref_type_not_eq() -> None: - rt1 = RefType(int_type()) - rt2 = RefType(float_type()) - assert not alpha_eq(rt1, rt2) - -def test_product_type_eq() -> None: - pt1 = TupleType([int_type(), RefType(float_type())]) - pt2 = TupleType([float_type(), float_type(), int_type()]) - assert alpha_eq(pt1, pt1) - assert alpha_eq(pt2, pt2) - -def test_product_type_not_eq() -> None: - pt1 = TupleType([int_type(), int_type()]) - pt2 = TupleType([int_type(), int_type(), float_type()]) - pt3 = TupleType([bool_type(), float_type()]) - assert not alpha_eq(pt1, pt2) - assert not alpha_eq(pt1, pt3) - -def test_type_id_eq() -> None: - id1 = TypeParam("id1", Kind.Shape) - id2 = TypeParam("id2", Kind.BaseType) - id3 = TypeParam("id2", Kind.Type) - - assert alpha_eq(id1, id1) - assert alpha_eq(id2, id2) - assert alpha_eq(id3, id3) - -def test_type_id_not_eq() -> None: - # name is just a hint, we use pointer equality as the rule - # (unless there is a quantifier to give context) - id1 = TypeParam("id1", Kind.Shape) - id2 = TypeParam("id1", Kind.Shape) - id3 = TypeParam("id3", Kind.BaseType) - - assert not alpha_eq(id1, id2) - assert not alpha_eq(id1, id3) - -def test_arrow_type_eq() -> None: - ar1 = TypeArrow([int_type()], bool_type()) - ar2 = TypeArrow([int_type(), int_type()], TupleType([])) - assert alpha_eq(ar1, ar1) - assert alpha_eq(ar2, ar2) - -def test_arrow_type_not_eq() -> None: - t1 = int_type() - t2 = bool_type() - t3 = [int_type(), bool_type()] - - assert not alpha_eq(TypeArrow([t1], t2), TypeArrow([t1], t1)) - assert not alpha_eq(TypeArrow(t3, t1), TypeArrow([t2], t1)) - assert not alpha_eq(TypeArrow([t1], TypeArrow([t1], t1)), - TypeArrow([t1], t1)) - -def test_type_quantifier_eq() -> None: - id1 = TypeParam("id1", Kind.Shape) - id2 = TypeParam("id2", Kind.Shape) - tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) - tq2 = TypeQuantifier(id2, TensorType(IntType(32), id2)) - - assert alpha_eq(tq1, tq1) - assert alpha_eq(tq1, tq2) - -def test_nested_type_quantifier_eq() -> None: - id1 = TypeParam("id1", Kind.BaseType) - id2 = TypeParam("id2", Kind.Shape) - id3 = TypeParam("id3", Kind.BaseType) - id4 = TypeParam("id4", Kind.Shape) - tq1 = TypeQuantifier(id1, TypeQuantifier(id2, TensorType(id1, id2))) - tq2 = TypeQuantifier(id3, TypeQuantifier(id4, TensorType(id3, id4))) - - assert alpha_eq(tq1, tq1) - assert alpha_eq(tq1, tq2) - -def test_type_quantifier_not_eq() -> None: - id1 = TypeParam("id1", Kind.Shape) - id2 = TypeParam("id2", Kind.BaseType) - id3 = TypeParam("id3", Kind.Shape) - - tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) - tq2 = TypeQuantifier(id2, TensorType(id2, ShapeSeq([ShapeSingleton(3)]))) - tq3 = TypeQuantifier(id1, TensorType(IntType(32), id3)) - tq4 = TypeQuantifier(id1, TensorType(FloatType(32), id1)) - - assert not alpha_eq(tq1, tq2) - assert not alpha_eq(tq1, tq3) - assert not alpha_eq(tq1, tq4) - assert not alpha_eq(tq2, tq3) - assert not alpha_eq(tq2, tq4) - -def test_shape_singleton_eq() -> None: - single1 = ShapeSingleton(10) - single2 = ShapeSingleton(10) - - assert alpha_eq(single1, single1) - assert alpha_eq(single1, single2) - -def test_shape_singelton_not_eq() -> None: - single1 = ShapeSingleton(10) - single2 = ShapeSingleton(11) - - assert not alpha_eq(single1, single2) - -def test_shape_attr_eq() -> None: - attr1 = ShapeAttr("x") - attr2 = ShapeAttr("x") - - assert alpha_eq(attr1, attr1) - assert alpha_eq(attr1, attr2) - -def test_shape_attr_not_eq() -> None: - id1 = "x" - id2 = "y" - attr1 = ShapeAttr(id1) - attr2 = ShapeAttr(id2) - - assert not alpha_eq(attr1, attr2) - -def test_shape_seq_eq() -> None: - empty = ShapeSeq([]) - seq1 = ShapeSeq([ShapeSingleton(5)]) - seq2 = ShapeSeq([ShapeSingleton(5)]) - - assert alpha_eq(empty, empty) - assert alpha_eq(seq1, seq2) - -def test_shape_seq_not_eq() -> None: - empty = ShapeSeq([]) - seq = ShapeSeq([ShapeSingleton(5)]) - single = ShapeSingleton(5) - - assert not alpha_eq(empty, seq) - assert not alpha_eq(seq, single) - -def test_shape_projection_eq() -> None: - proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) - proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) - - assert alpha_eq(proj1, proj2) - -def test_shape_projection_not_eq() -> None: - proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) - proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 1) - proj3 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 0) - proj4 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 1) - - assert not alpha_eq(proj1, proj2) - assert not alpha_eq(proj1, proj3) - assert not alpha_eq(proj1, proj4) - assert not alpha_eq(proj2, proj3) - assert not alpha_eq(proj2, proj4) - assert not alpha_eq(proj3, proj4) - -def test_shape_binary_op_eq() -> None: - empty = ShapeSeq([]) - single = ShapeSingleton(5) - seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) - - op1 = ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty) - op2 = ShapeBinaryOp(ShapeOp.SHSUB, single, single) - op3 = ShapeBinaryOp(ShapeOp.SHMUL, seq, seq) - op4 = ShapeBinaryOp(ShapeOp.SHDIV, seq, seq) - - assert alpha_eq(op1, op1) - assert alpha_eq(op2, op2) - assert alpha_eq(op3, op3) - assert alpha_eq(op4, op4) - -def test_shape_binary_op_not_eq() -> None: - empty = ShapeSeq([]) - single = ShapeSingleton(5) - seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) - - assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), empty) - assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHMUL, seq, ShapeSingleton(1)), seq) - assert not alpha_eq( - ShapeBinaryOp(ShapeOp.SHPLUS, single, single), - ShapeBinaryOp(ShapeOp.SHPLUS, - ShapeSeq([single]), - ShapeSeq([single]))) - assert not alpha_eq( - ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), - ShapeBinaryOp(ShapeOp.SHSUB, empty, empty)) - assert not alpha_eq( - ShapeBinaryOp(ShapeOp.SHMUL, empty, empty), - ShapeBinaryOp(ShapeOp.SHDIV, empty, empty)) - -def test_shape_nested_in_quantifier() -> None: - b1 = TypeParam("b", Kind.BaseType) - x1 = TypeParam("x", Kind.Shape) - y1 = TypeParam("y", Kind.Shape) - - b2 = TypeParam("b", Kind.BaseType) - x2 = TypeParam("x", Kind.Shape) - y2 = TypeParam("y", Kind.Shape) - - b3 = TypeParam("b", Kind.BaseType) - x3 = TypeParam("x", Kind.Shape) - y3 = TypeParam("y", Kind.Shape) - - tq1 = nest_quantifiers( - [b1, x1, y1], - TypeArrow( - [TensorType(b1, x1), TensorType(b1, y2)], - TensorType( - b1, - ShapeBinaryOp(ShapeOp.SHPLUS, - ShapeSeq([x1, ShapeProjection(y1, 1), - ShapeSingleton(5), ShapeAttr("att")]), - ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - - tq2 = nest_quantifiers( - [b2, x2, y2], - TypeArrow( - [TensorType(b2, x2), TensorType(b2, y2)], - TensorType( - b2, - ShapeBinaryOp(ShapeOp.SHPLUS, - ShapeSeq([x2, ShapeProjection(y2, 1), - ShapeSingleton(5), ShapeAttr("att")]), - ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - - # different attr, var order, position, and constant - tq3 = nest_quantifiers( - [b3, x3, y3], - TypeArrow( - [TensorType(b3, x3), TensorType(b3, y3)], - TensorType( - b3, - ShapeBinaryOp(ShapeOp.SHPLUS, - ShapeSeq([x3, ShapeProjection(y3, 1), - ShapeSingleton(4), ShapeAttr("att")]), - ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - - tq4 = nest_quantifiers( - [b3, x3, y3], - TypeArrow( - [TensorType(b3, x3), TensorType(b3, y3)], - TensorType( - b3, - ShapeBinaryOp(ShapeOp.SHPLUS, - ShapeSeq([x3, ShapeProjection(y3, 2), - ShapeSingleton(5), ShapeAttr("att2")]), - ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - - tq5 = nest_quantifiers( - [b3, x3, y3], - TypeArrow( - [TensorType(b3, x3), TensorType(b3, y3)], - TensorType( - b3, - ShapeBinaryOp(ShapeOp.SHMUL, - ShapeSeq([x3, ShapeProjection(y3, 1), - ShapeSingleton(5), ShapeAttr("att")]), - ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - - tq6 = nest_quantifiers( - [b3, y3, x3], - TypeArrow( - [TensorType(b3, x3), TensorType(b3, y3)], - TensorType( - b3, - ShapeBinaryOp(ShapeOp.SHPLUS, - ShapeSeq([x3, ShapeProjection(y3, 1), - ShapeSingleton(5), ShapeAttr("att")]), - ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - - assert alpha_eq(tq1, tq2) - assert not alpha_eq(tq1, tq3) - assert not alpha_eq(tq2, tq3) - assert not alpha_eq(tq1, tq4) - assert not alpha_eq(tq2, tq4) - assert not alpha_eq(tq1, tq5) - assert not alpha_eq(tq2, tq5) - assert not alpha_eq(tq1, tq6) - assert not alpha_eq(tq2, tq6) +from tvm.relay import make as mk +# from relay.ir import alpha_eq, ShapeOp, Kind +# from relay.typing import TYPE_DEFAULTS +# from relay import ir + +# INT_TYPE_WIDTH = TYPE_DEFAULTS["INT_WIDTH"] +# INT_TYPE_LANES = TYPE_DEFAULTS["INT_LANES"] + +# def int_type(width=32) -> ir.Type: +# return TensorType(IntType(width), ShapeSeq([])) + +# def float_type(width=32) -> ir.Type: +# return TensorType(FloatType(width), ShapeSeq([])) + +# def bool_type() -> ir.Type: +# return TensorType(BoolType(), ShapeSeq([])) + +# def nest_quantifiers(ids, body) -> ir.Type: +# ret = body +# for tid in reversed(ids): +# ret = TypeQuantifier(tid, ret) +# return ret + +# def test_local_id_not_eq() -> None: +# assert not alpha_eq(LocalId("x"), LocalId("y")) + +# def test_local_id_eq() -> None: +# x = LocalId("x") +# assert alpha_eq(x, x) + +# def test_global_id_not_eq() -> None: +# left = GlobalId("xyz") +# right = GlobalId("xyz") +# assert not alpha_eq(left, right) + +# def test_global_id_eq() -> None: +# ident = GlobalId("xyz") +# assert alpha_eq(ident, ident) + +# def test_operator_id_not_eq() -> None: +# left = OperatorId("xyz") +# right = OperatorId("xyz") +# # equality on operator id is pointer equality +# assert not alpha_eq(left, right) + +# def test_operator_id_eq() -> None: +# x = OperatorId("xyz") +# assert alpha_eq(x, x) + +# def test_float_literal_eq() -> None: +# x = FloatLit(1.0) +# y = FloatLit(1.0) +# assert alpha_eq(x, y) + +# def test_float_literal_not_eq() -> None: +# x = FloatLit(1.0) +# y = FloatLit(2.0) +# assert not alpha_eq(x, y) + +# def test_int_literal_eq() -> None: +# x = IntLit(1) +# y = IntLit(1) +# assert alpha_eq(x, y) + +# def test_int_literal_not_eq() -> None: +# x = IntLit(1) +# y = IntLit(2) +# assert not alpha_eq(x, y) + +# def test_bool_literal_eq() -> None: +# x = BoolLit(True) +# y = BoolLit(True) +# assert alpha_eq(x, y) + +# def test_bool_literal_not_eq() -> None: +# x = BoolLit(True) +# y = BoolLit(False) +# assert not alpha_eq(x, y) + +# def test_tensor_literal_eq() -> None: +# x = TensorLit([IntLit(1), IntLit(2)]) +# y = TensorLit([IntLit(1), IntLit(2)]) +# assert alpha_eq(x, y) + +# def test_tensor_literal_not_eq() -> None: +# x = TensorLit([IntLit(1), IntLit(2)]) +# y = TensorLit([IntLit(1), IntLit(3)]) +# z = TensorLit([IntLit(1)]) +# assert not alpha_eq(x, y) +# assert not alpha_eq(x, z) + +# def test_product_literal_eq() -> None: +# x = Tuple([IntLit(1), IntLit(2)]) +# y = Tuple([IntLit(1), IntLit(2)]) +# assert alpha_eq(x, y) + +# def test_product_literal_not_eq() -> None: +# x = Tuple([IntLit(1), IntLit(2)]) +# y = Tuple([IntLit(2), IntLit(2)]) +# z = Tuple([IntLit(1), IntLit(2), IntLit(3)]) +# assert not alpha_eq(x, y) +# assert not alpha_eq(x, z) + +# def test_projection_eq() -> None: +# prod = Tuple([IntLit(3), FloatLit(3.5)]) + +# assert alpha_eq(Projection(prod, 0), Projection(prod, 0)) +# assert alpha_eq(Projection(prod, 1), Projection(prod, 1)) + +# def test_projection_not_eq() -> None: +# prod1 = Tuple([IntLit(3), IntLit(4)]) +# prod2 = Tuple([IntLit(3)]) +# prod3 = Tuple([IntLit(3), IntLit(4), FloatLit(3.5)]) + +# assert not alpha_eq(Projection(prod1, 0), Projection(prod1, 1)) +# assert not alpha_eq(Projection(prod1, 0), Projection(prod2, 0)) +# assert not alpha_eq(Projection(prod1, 0), Projection(prod3, 0)) +# assert not alpha_eq(Projection(prod1, 1), Projection(prod3, 1)) + +# def test_cast_not_eq() -> None: +# left = Cast(IntType(1), IntLit(2)) +# right = Cast(IntType(1), IntLit(1)) +# assert not alpha_eq(left, right) + +# # same literal, different type +# left = Cast(IntType(1), IntLit(2)) +# right = Cast(IntType(2), IntLit(2)) +# assert not alpha_eq(left, right) + +# def test_cast_eq() -> None: +# left = Cast(IntType(1), IntLit(2)) +# right = Cast(IntType(1), IntLit(2)) +# assert alpha_eq(left, right) + +# def test_param_not_eq() -> None: +# left = Param(LocalId("foo"), int_type()) +# right = Param(LocalId("foo"), bool_type()) +# assert not alpha_eq(left, right) + +# def test_param_eq() -> None: +# left = Param(LocalId("foo"), int_type()) +# right = Param(LocalId("bar"), int_type()) +# assert alpha_eq(left, right) + +# def test_function_not_eq() -> None: +# params1 = [Param(LocalId("x"), int_type())] +# fn1 = Function([], params1, int_type(), LocalId("x")) +# params2 = [Param(LocalId("y"), bool_type())] +# fn2 = Function([], params2, int_type(), LocalId("y")) +# assert not alpha_eq(fn1, fn2) + +# params3 = [Param(LocalId("x"), int_type()), Param(LocalId("y"), int_type())] +# fn3 = Function([], params3, int_type(), LocalId("z")) +# assert not alpha_eq(fn1, fn3) + +# def test_function_eq() -> None: +# x = LocalId("x") +# y = LocalId("y") +# params1 = [Param(x, int_type())] +# fn1 = Function([], params1, int_type(), x) +# params2 = [Param(y, int_type())] +# fn2 = Function([], params2, int_type(), y) +# assert alpha_eq(fn1, fn2) + +# def test_call_not_eq() -> None: +# x = LocalId("x") +# y = LocalId("y") +# params1 = [Param(x, int_type())] +# fn1 = Function([], params1, int_type(), x) +# args1 = [IntLit(1)] +# call1 = Call(fn1, args1) + +# args2 = [IntLit(2)] +# call2 = Call(fn1, args2) +# assert not alpha_eq(call1, call2) + +# params2 = [Param(y, int_type())] +# fn2 = Function([], params2, float_type(), FloatLit(0.0)) +# call3 = Call(fn2, args1) +# assert not alpha_eq(call1, call3) +# assert not alpha_eq(call2, call3) + +# def test_call_eq() -> None: +# x = LocalId("x") +# y = LocalId("y") +# params1 = [Param(x, int_type())] +# fn1 = Function([], params1, int_type(), x) +# args = [IntLit(1)] +# call1 = Call(fn1, args) + +# params2 = [Param(y, int_type())] +# fn2 = Function([], params2, int_type(), y) +# call2 = Call(fn2, args) +# assert alpha_eq(call1, call2) + +# def test_debug_not_eq() -> None: +# left = Debug(IntLit(1)) +# right = Debug(IntLit(2)) +# assert not alpha_eq(left, right) + +# def test_debug_eq() -> None: +# left = Debug(IntLit(1)) +# right = Debug(IntLit(1)) +# assert alpha_eq(left, right) + +# def test_let_not_eq() -> None: +# x = LocalId("x") +# y = LocalId("y") +# let1 = Let(x, int_type(), IntLit(10), IntLit(11)) +# let2 = Let(y, int_type(), IntLit(10), IntLit(12)) +# assert not alpha_eq(let1, let2) + +# let3 = Let(x, int_type(), IntLit(10), x) +# let4 = Let(y, int_type(), IntLit(12), y) +# assert not alpha_eq(let3, let4) + +# def test_let_eq() -> None: +# x = LocalId("x") +# y = LocalId("y") +# let1 = Let(x, int_type(), IntLit(10), x) +# let2 = Let(y, int_type(), IntLit(10), y) +# assert alpha_eq(let1, let2) + +# def test_ref_eq() -> None: +# r1 = Ref(IntLit(5)) +# r2 = Ref(IntLit(5)) +# assert alpha_eq(r1, r2) + +# def test_ref_not_eq() -> None: +# r1 = Ref(IntLit(5)) +# r2 = Ref(FloatLit(3.5)) +# r3 = Ref(r1) +# assert not alpha_eq(r1, r2) +# assert not alpha_eq(r1, r3) +# assert not alpha_eq(r2, r3) + +# def test_val_ref_eq() -> None: +# vr1 = ReadRef(Ref(IntLit(35))) +# vr2 = ReadRef(Ref(Tuple([IntLit(12), FloatLit(2.5)]))) +# assert alpha_eq(vr1, vr1) +# assert alpha_eq(vr2, vr2) + +# def test_val_ref_not_eq() -> None: +# vr1 = ReadRef(Ref(IntLit(5))) +# vr2 = ReadRef(Ref(vr1)) +# vr3 = ReadRef(Ref(FloatLit(5.0))) +# assert not alpha_eq(vr1, vr2) +# assert not alpha_eq(vr1, vr3) +# assert not alpha_eq(vr2, vr3) + +# def test_set_ref_eq() -> None: +# sr1 = WriteRef(Ref(FloatLit(5.0)), FloatLit(6.0)) +# sr2 = WriteRef(Ref(Tuple([IntLit(3), BoolLit(False)])), +# Tuple([IntLit(5), BoolLit(True)])) +# assert alpha_eq(sr1, sr1) +# assert alpha_eq(sr2, sr2) + +# def test_set_ref_not_eq() -> None: +# r1 = Ref(FloatLit(5.0)) +# r2 = Ref(IntLit(5)) +# r3 = Ref(IntLit(6)) + +# assert not alpha_eq(WriteRef(r1, FloatLit(6.0)), +# WriteRef(r2, IntLit(6))) +# assert not alpha_eq(WriteRef(r2, IntLit(6)), WriteRef(r2, IntLit(7))) +# assert not alpha_eq(WriteRef(r2, IntLit(7)), WriteRef(r3, IntLit(7))) + +# # Type alpha-equality tests + +# def test_base_type_eq() -> None: +# assert alpha_eq(IntType(32), IntType(32)) +# assert alpha_eq(BoolType(), BoolType()) +# assert alpha_eq(FloatType(32), FloatType(32)) + +# def test_tensor_type_eq() -> None: +# tt1 = TensorType( +# IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) +# tt2 = TensorType( +# FloatType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) +# assert alpha_eq(tt1, tt1) +# assert alpha_eq(tt2, tt2) + +# def test_tensor_type_not_eq() -> None: +# tt1 = TensorType( +# IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) +# tt2 = TensorType( +# FloatType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) +# tt3 = TensorType( +# IntType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) +# assert not alpha_eq(tt1, tt2) +# assert not alpha_eq(tt1, tt3) + +# def test_ref_type_eq() -> None: +# rt1 = RefType(int_type()) +# rt2 = RefType(float_type()) +# assert alpha_eq(rt1, rt1) +# assert alpha_eq(rt2, rt2) + +# def test_ref_type_not_eq() -> None: +# rt1 = RefType(int_type()) +# rt2 = RefType(float_type()) +# assert not alpha_eq(rt1, rt2) + +# def test_product_type_eq() -> None: +# pt1 = TupleType([int_type(), RefType(float_type())]) +# pt2 = TupleType([float_type(), float_type(), int_type()]) +# assert alpha_eq(pt1, pt1) +# assert alpha_eq(pt2, pt2) + +# def test_product_type_not_eq() -> None: +# pt1 = TupleType([int_type(), int_type()]) +# pt2 = TupleType([int_type(), int_type(), float_type()]) +# pt3 = TupleType([bool_type(), float_type()]) +# assert not alpha_eq(pt1, pt2) +# assert not alpha_eq(pt1, pt3) + +# def test_type_id_eq() -> None: +# id1 = TypeParam("id1", Kind.Shape) +# id2 = TypeParam("id2", Kind.BaseType) +# id3 = TypeParam("id2", Kind.Type) + +# assert alpha_eq(id1, id1) +# assert alpha_eq(id2, id2) +# assert alpha_eq(id3, id3) + +# def test_type_id_not_eq() -> None: +# # name is just a hint, we use pointer equality as the rule +# # (unless there is a quantifier to give context) +# id1 = TypeParam("id1", Kind.Shape) +# id2 = TypeParam("id1", Kind.Shape) +# id3 = TypeParam("id3", Kind.BaseType) + +# assert not alpha_eq(id1, id2) +# assert not alpha_eq(id1, id3) + +# def test_arrow_type_eq() -> None: +# ar1 = TypeArrow([int_type()], bool_type()) +# ar2 = TypeArrow([int_type(), int_type()], TupleType([])) +# assert alpha_eq(ar1, ar1) +# assert alpha_eq(ar2, ar2) + +# def test_arrow_type_not_eq() -> None: +# t1 = int_type() +# t2 = bool_type() +# t3 = [int_type(), bool_type()] + +# assert not alpha_eq(TypeArrow([t1], t2), TypeArrow([t1], t1)) +# assert not alpha_eq(TypeArrow(t3, t1), TypeArrow([t2], t1)) +# assert not alpha_eq(TypeArrow([t1], TypeArrow([t1], t1)), +# TypeArrow([t1], t1)) + +# def test_type_quantifier_eq() -> None: +# id1 = TypeParam("id1", Kind.Shape) +# id2 = TypeParam("id2", Kind.Shape) +# tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) +# tq2 = TypeQuantifier(id2, TensorType(IntType(32), id2)) + +# assert alpha_eq(tq1, tq1) +# assert alpha_eq(tq1, tq2) + +# def test_nested_type_quantifier_eq() -> None: +# id1 = TypeParam("id1", Kind.BaseType) +# id2 = TypeParam("id2", Kind.Shape) +# id3 = TypeParam("id3", Kind.BaseType) +# id4 = TypeParam("id4", Kind.Shape) +# tq1 = TypeQuantifier(id1, TypeQuantifier(id2, TensorType(id1, id2))) +# tq2 = TypeQuantifier(id3, TypeQuantifier(id4, TensorType(id3, id4))) + +# assert alpha_eq(tq1, tq1) +# assert alpha_eq(tq1, tq2) + +# def test_type_quantifier_not_eq() -> None: +# id1 = TypeParam("id1", Kind.Shape) +# id2 = TypeParam("id2", Kind.BaseType) +# id3 = TypeParam("id3", Kind.Shape) + +# tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) +# tq2 = TypeQuantifier(id2, TensorType(id2, ShapeSeq([ShapeSingleton(3)]))) +# tq3 = TypeQuantifier(id1, TensorType(IntType(32), id3)) +# tq4 = TypeQuantifier(id1, TensorType(FloatType(32), id1)) + +# assert not alpha_eq(tq1, tq2) +# assert not alpha_eq(tq1, tq3) +# assert not alpha_eq(tq1, tq4) +# assert not alpha_eq(tq2, tq3) +# assert not alpha_eq(tq2, tq4) + +# def test_shape_singleton_eq() -> None: +# single1 = ShapeSingleton(10) +# single2 = ShapeSingleton(10) + +# assert alpha_eq(single1, single1) +# assert alpha_eq(single1, single2) + +# def test_shape_singelton_not_eq() -> None: +# single1 = ShapeSingleton(10) +# single2 = ShapeSingleton(11) + +# assert not alpha_eq(single1, single2) + +# def test_shape_attr_eq() -> None: +# attr1 = ShapeAttr("x") +# attr2 = ShapeAttr("x") + +# assert alpha_eq(attr1, attr1) +# assert alpha_eq(attr1, attr2) + +# def test_shape_attr_not_eq() -> None: +# id1 = "x" +# id2 = "y" +# attr1 = ShapeAttr(id1) +# attr2 = ShapeAttr(id2) + +# assert not alpha_eq(attr1, attr2) + +# def test_shape_seq_eq() -> None: +# empty = ShapeSeq([]) +# seq1 = ShapeSeq([ShapeSingleton(5)]) +# seq2 = ShapeSeq([ShapeSingleton(5)]) + +# assert alpha_eq(empty, empty) +# assert alpha_eq(seq1, seq2) + +# def test_shape_seq_not_eq() -> None: +# empty = ShapeSeq([]) +# seq = ShapeSeq([ShapeSingleton(5)]) +# single = ShapeSingleton(5) + +# assert not alpha_eq(empty, seq) +# assert not alpha_eq(seq, single) + +# def test_shape_projection_eq() -> None: +# proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) +# proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) + +# assert alpha_eq(proj1, proj2) + +# def test_shape_projection_not_eq() -> None: +# proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) +# proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 1) +# proj3 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 0) +# proj4 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 1) + +# assert not alpha_eq(proj1, proj2) +# assert not alpha_eq(proj1, proj3) +# assert not alpha_eq(proj1, proj4) +# assert not alpha_eq(proj2, proj3) +# assert not alpha_eq(proj2, proj4) +# assert not alpha_eq(proj3, proj4) + +# def test_shape_binary_op_eq() -> None: +# empty = ShapeSeq([]) +# single = ShapeSingleton(5) +# seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) + +# op1 = ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty) +# op2 = ShapeBinaryOp(ShapeOp.SHSUB, single, single) +# op3 = ShapeBinaryOp(ShapeOp.SHMUL, seq, seq) +# op4 = ShapeBinaryOp(ShapeOp.SHDIV, seq, seq) + +# assert alpha_eq(op1, op1) +# assert alpha_eq(op2, op2) +# assert alpha_eq(op3, op3) +# assert alpha_eq(op4, op4) + +# def test_shape_binary_op_not_eq() -> None: +# empty = ShapeSeq([]) +# single = ShapeSingleton(5) +# seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) + +# assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), empty) +# assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHMUL, seq, ShapeSingleton(1)), seq) +# assert not alpha_eq( +# ShapeBinaryOp(ShapeOp.SHPLUS, single, single), +# ShapeBinaryOp(ShapeOp.SHPLUS, +# ShapeSeq([single]), +# ShapeSeq([single]))) +# assert not alpha_eq( +# ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), +# ShapeBinaryOp(ShapeOp.SHSUB, empty, empty)) +# assert not alpha_eq( +# ShapeBinaryOp(ShapeOp.SHMUL, empty, empty), +# ShapeBinaryOp(ShapeOp.SHDIV, empty, empty)) + +# def test_shape_nested_in_quantifier() -> None: +# b1 = TypeParam("b", Kind.BaseType) +# x1 = TypeParam("x", Kind.Shape) +# y1 = TypeParam("y", Kind.Shape) + +# b2 = TypeParam("b", Kind.BaseType) +# x2 = TypeParam("x", Kind.Shape) +# y2 = TypeParam("y", Kind.Shape) + +# b3 = TypeParam("b", Kind.BaseType) +# x3 = TypeParam("x", Kind.Shape) +# y3 = TypeParam("y", Kind.Shape) + +# tq1 = nest_quantifiers( +# [b1, x1, y1], +# TypeArrow( +# [TensorType(b1, x1), TensorType(b1, y2)], +# TensorType( +# b1, +# ShapeBinaryOp(ShapeOp.SHPLUS, +# ShapeSeq([x1, ShapeProjection(y1, 1), +# ShapeSingleton(5), ShapeAttr("att")]), +# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + +# tq2 = nest_quantifiers( +# [b2, x2, y2], +# TypeArrow( +# [TensorType(b2, x2), TensorType(b2, y2)], +# TensorType( +# b2, +# ShapeBinaryOp(ShapeOp.SHPLUS, +# ShapeSeq([x2, ShapeProjection(y2, 1), +# ShapeSingleton(5), ShapeAttr("att")]), +# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + +# # different attr, var order, position, and constant +# tq3 = nest_quantifiers( +# [b3, x3, y3], +# TypeArrow( +# [TensorType(b3, x3), TensorType(b3, y3)], +# TensorType( +# b3, +# ShapeBinaryOp(ShapeOp.SHPLUS, +# ShapeSeq([x3, ShapeProjection(y3, 1), +# ShapeSingleton(4), ShapeAttr("att")]), +# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + +# tq4 = nest_quantifiers( +# [b3, x3, y3], +# TypeArrow( +# [TensorType(b3, x3), TensorType(b3, y3)], +# TensorType( +# b3, +# ShapeBinaryOp(ShapeOp.SHPLUS, +# ShapeSeq([x3, ShapeProjection(y3, 2), +# ShapeSingleton(5), ShapeAttr("att2")]), +# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + +# tq5 = nest_quantifiers( +# [b3, x3, y3], +# TypeArrow( +# [TensorType(b3, x3), TensorType(b3, y3)], +# TensorType( +# b3, +# ShapeBinaryOp(ShapeOp.SHMUL, +# ShapeSeq([x3, ShapeProjection(y3, 1), +# ShapeSingleton(5), ShapeAttr("att")]), +# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + +# tq6 = nest_quantifiers( +# [b3, y3, x3], +# TypeArrow( +# [TensorType(b3, x3), TensorType(b3, y3)], +# TensorType( +# b3, +# ShapeBinaryOp(ShapeOp.SHPLUS, +# ShapeSeq([x3, ShapeProjection(y3, 1), +# ShapeSingleton(5), ShapeAttr("att")]), +# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + +# assert alpha_eq(tq1, tq2) +# assert not alpha_eq(tq1, tq3) +# assert not alpha_eq(tq2, tq3) +# assert not alpha_eq(tq1, tq4) +# assert not alpha_eq(tq2, tq4) +# assert not alpha_eq(tq1, tq5) +# assert not alpha_eq(tq2, tq5) +# assert not alpha_eq(tq1, tq6) +# assert not alpha_eq(tq2, tq6) diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py index b2ed075ca3de..065a91f0abbe 100644 --- a/tests/python/relay/test_unifier.py +++ b/tests/python/relay/test_unifier.py @@ -2,23 +2,23 @@ Test the type unifier, which solves systems of equations between incomplete types. """ -import tvm.relay.ir +from tvm.relay import ir from tvm.relay.unifier import UnionFind, TypeUnifier import tvm.relay.make as mk def test_insert_and_find(): - uf = UnionFind() - v1 = mk.TypeVar(ir.Kind.Type) - v2 = mk.TypeVar(ir.Kind.Type) + uf = mk.UnionFind()() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) uf.insert(v1) uf.insert(v2) assert uf.find(v1) == v1 assert uf.find(v2) == v2 def test_insert_error(): - uf = UnionFind() - v1 = mk.TypeVar(ir.Kind.Type) - v2 = mk.TypeVar(ir.Kind.Type) + uf = mk.UnionFind()() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) uf.insert(v1) try: uf.find(v2) @@ -27,10 +27,10 @@ def test_insert_error(): return def test_unify(): - uf = UnionFind() - v1 = mk.TypeVar(ir.Kind.Type) - v2 = mk.TypeVar(ir.Kind.Type) - v3 = mk.TypeVar(ir.Kind.Type) + uf = mk.UnionFind()() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) + v3 = mk.IncompleteType(ir.Kind.Type) uf.insert(v1) uf.insert(v2) uf.insert(v3) @@ -49,8 +49,8 @@ def test_unify(): assert uf.find(v3) == new_rep def test_unify_multiple_levels(): - uf = UnionFind() - v = [TypeVar(ir.Kind.Type) for _ in range(9)] + uf = mk.UnionFind()() + v = [mk.IncompleteType(ir.Kind.Type) for _ in range(9)] for var in v: uf.insert(var) uf.unify(v[0], v[1]) @@ -85,7 +85,7 @@ def test_unify_multiple_levels(): # and now we will test the type unifier which will fill in holes # between type equalities by the process of unification. def unify_types(t1, t2): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() return unifier.unify(t1, t2) # TODO(sslyu, weberlo, joshpoll): put in isinstance asserts once those work @@ -120,8 +120,8 @@ def test_unify_concrete_type_arrow(): assert unified == arr1 def test_unify_type_arrow_with_holes(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.unify(v1, bool_type()) arr1 = TypeArrow([int_type()], bool_type()) @@ -129,7 +129,7 @@ def test_unify_type_arrow_with_holes(): unified = unifier.unify(arr1, arr2) assert unified == arr1 - v2 = TypeVar(ir.Kind.BaseType) + v2 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v2) unifier.unify(v2, int_type()) arr3 = TypeArrow([v2], bool_type()) @@ -161,10 +161,10 @@ def test_unify_basetype_with_quantifier_error(): return def test_unify_typevars_with_each_other(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) - v3 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) + v3 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unifier.insert(v3) @@ -175,10 +175,10 @@ def test_unify_typevars_with_each_other(): assert (new_unified == v1 or new_unified == v2 or new_unified == v3) def test_unify_typevars_with_basetype(): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() bt = BoolType() - v1 = TypeVar(ir.Kind.BaseType) - v2 = TypeVar(ir.Kind.BaseType) + v1 = mk.IncompleteType(ir.Kind.BaseType) + v2 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.insert(v2) unified1 = unifier.unify(v1, bt) @@ -187,10 +187,10 @@ def test_unify_typevars_with_basetype(): assert unified2 == bt def test_unify_compatible_typevars(): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() bt = BoolType() - v1 = TypeVar(ir.Kind.BaseType) - v2 = TypeVar(ir.Kind.BaseType) + v1 = mk.IncompleteType(ir.Kind.BaseType) + v2 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.insert(v2) unifier.unify(v1, bt) @@ -201,9 +201,9 @@ def test_unify_compatible_typevars(): assert unified == bt def test_unify_incompatible_typevars(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) - v2 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) + v2 = mk.IncompleteType(ir.Kind.BaseType) bt = bool_type() tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) unifier.insert(v1) @@ -218,16 +218,16 @@ def test_unify_incompatible_typevars(): return def test_unify_typevar_with_quantifier(): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bool_type()) - v1 = TypeVar(ir.Kind.BaseType) + v1 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unified = unifier.unify(v1, tq) assert unified == tq def test_unify_typevars_inside_concrete_quantifier(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v1) tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), bool_type()) @@ -312,8 +312,8 @@ def test_unify_products_reject_member(): return def test_unify_products_typevar(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) bt = bool_type() pt1 = TupleType([bt, bt]) pt2 = TupleType([v1, bt]) @@ -344,22 +344,22 @@ def test_unify_ref_reject_inner(): return def test_subst_basetype(): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() bt = BoolType() assert bt == unifier.subst(bt) def test_subst_simple_hole(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) bt = BoolType() unifier.insert(v1) unifier.unify(v1, bt) assert unifier.subst(v1) == bt def test_subst_typevar_for_typevar(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) @@ -367,14 +367,14 @@ def test_subst_typevar_for_typevar(): assert unifier.subst(v1) == v2 def test_subst_concrete_arrow(): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() arr1 = TypeArrow([int_type()], int_type()) assert unifier.subst(arr1) == arr1 def test_subst_arrow_with_holes(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) - v2 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) + v2 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.insert(v2) unifier.unify(v1, int_type()) @@ -384,17 +384,17 @@ def test_subst_arrow_with_holes(): assert unifier.subst(arr1) == arr2 def test_subst_concrete_quantifier(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) unifier.insert(v1) unifier.unify(v1, tq) assert unifier.subst(v1) == tq def test_subst_quantifier_with_holes(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v2) intty = int_type() tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), intty) @@ -406,16 +406,16 @@ def test_subst_quantifier_with_holes(): assert unifier.subst(v1) == tq2 def test_subst_concrete_tensor(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) tt = TensorType(BoolType(), make_shape([1, 2, 3])) unifier.unify(v1, tt) assert unifier.subst(v1) == tt def test_subst_concrete_product(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) bt = bool_type() pt = TupleType([bt, bt]) @@ -423,10 +423,10 @@ def test_subst_concrete_product(): assert unifier.subst(v1) == pt def test_subst_product_with_holes(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) - v3 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) + v3 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unifier.insert(v3) @@ -441,13 +441,13 @@ def test_subst_product_with_holes(): assert unifier.subst(v1) == pt2 def test_subst_concrete_ref(): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() rt = RefType(bool_type()) assert unifier.subst(rt) == rt def test_subst_ref_with_hole(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.unify(v1, bool_type()) @@ -456,9 +456,9 @@ def test_subst_ref_with_hole(): assert unifier.subst(rt1) == rt2 def test_typevar_on_lhs(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) - v2 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) + v2 = mk.IncompleteType(ir.Kind.Type) bt = bool_type() tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt, bt) unifier.insert(v1) From ff49fdb82315b9c216bc9492f895f08958ebef2c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 00:46:19 -0700 Subject: [PATCH 23/88] Remove tests for ommitted features and fix Remove the tests for features we don't currently support, and fix the tests which were left. --- include/tvm/relay/expr_visitor.h | 6 +- include/tvm/relay/type.h | 12 + python/tvm/relay/ir_builder.py | 23 +- python/tvm/relay/make.py | 41 +++ python/tvm/relay/type.py | 16 +- src/relay/compiler/alpha_eq.cc | 16 +- src/relay/compiler/type_visitor.h | 115 ++++--- src/relay/compiler/typechecker.cc | 2 +- src/relay/compiler/unifier.cc | 126 +++---- src/relay/compiler/unifier.h | 2 +- src/relay/type.cc | 38 +++ tests/python/relay/test_unifier.py | 509 +++++++++++++++-------------- 12 files changed, 508 insertions(+), 398 deletions(-) diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index d7ac1465f70a..721fa531a7e3 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -7,13 +7,13 @@ #ifndef TVM_RELAY_EXPR_VISITOR_H_ #define TVM_RELAY_EXPR_VISITOR_H_ -#include "expr_functor.h" +#include "tvm/relay/expr_functor.h" namespace tvm { namespace relay { template -class ExprVisitor : public ExprFunctor { +class ExprVisitor : public ::tvm::relay::ExprFunctor { public: void VisitExpr_(const LocalVarNode* op, Args... args) override { return; } @@ -62,7 +62,7 @@ class ExprVisitor : public ExprFunctor { }; template -class ExprFVisitor : public ExprFunctor { +class ExprFVisitor : public ::tvm::relay::ExprFunctor { public: Expr VisitExpr_(const LocalVarNode* op, Args... args) override { return GetRef(op); diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 4eeb42168d68..07b047471aba 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -83,6 +83,18 @@ class TensorTypeNode : public BaseTensorTypeNode { TVM_DLL static TensorType make(Array shape, DataType dtype); + /*! \brief Constructing an unsigned integer type */ + TVM_DLL static TensorType Int(int bits, int lanes = 1); + + /*! \brief Constructing an unsigned integer type */ + TVM_DLL static TensorType UInt(int bits, int lanes = 1); + + /*! \brief Construct a floating-point type */ + TVM_DLL static TensorType Float(int bits, int lanes = 1); + + /*1 \brief Construct a boolean type */ + TVM_DLL static TensorType Bool(int lanes = 1); + static constexpr const char* _type_key = "relay.TensorType"; TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode); }; diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 3c842e480c70..8fa9b789f53c 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -100,16 +100,23 @@ def get(self): return _mk_let(bindings, self.ret_value) -# def int_type(): -# return TensorType(IntType(32), ShapeSeq([])) +def bool_dtype(): + return 'uint1' -# def float_type(): -# return TensorType(FloatType(32), ShapeSeq([])) +def int_dtype(): + return 'uint1' -# def bool_type(): -# return TensorType(BoolType(), ShapeSeq([])) +def int_type(bits=32, lanes=1): + return mk.IntType(bits, lanes) -# def make_shape(dims): -# return ShapeSeq([ShapeSingleton(dim) for dim in dims]) +def uint_type(bits=32, lanes=1): + return mk.UIntType(bits, lanes) +def float_type(bits=32, lanes=1): + return mk.FloatType(bits, lanes) +def bool_type(lanes=1): + return mk.BoolType(lanes) + +def func_type(args, ret_type, type_params=[], type_constraints=[]): + return mk.FuncType(args, ret_type, type_params, type_constraints) diff --git a/python/tvm/relay/make.py b/python/tvm/relay/make.py index a2b87f2700af..236e2f6af596 100644 --- a/python/tvm/relay/make.py +++ b/python/tvm/relay/make.py @@ -1,4 +1,5 @@ from . import _make +from . import ir # Base Constructors Span = _make.Span @@ -8,6 +9,43 @@ TypeParam = _make.TypeParam FuncType = _make.FuncType +# Types +def IntType(bits: int, lanes: int=1) -> ir.Type: + """Constructs a integer base type. + + :param bits: The bit width of the integer type. + :param lanes: The number of vector elements for this datatype. + + """ + return _make.IntType(bits, lanes) + + +def UIntType(bits: int, lanes: int=1) -> ir.Type: + """Constructs a unsigned integer base type. + + :param bits: The bit width of the unsigned type. + :param lanes: The number of vector elements for this datatype. + """ + return _make.UIntType(bits, lanes) + + +def FloatType(bits: int, lanes: int=1) -> ir.Type: + """Constructs a floating point base type. + + :param bits: The bit width of the unsigned type. + :param lanes: The number of vector elements for this datatype. + """ + return _make.FloatType(bits, lanes) + + +def BoolType(lanes: int =1) -> ir.Type: + """Constructs a boolean base type. + + :param bits: The bit width of the unsigned type. + :param lanes: The number of vector elements for this datatype. + """ + return _make.BoolType(lanes) + # Expr Constructors Constant = _make.Constant Tuple = _make.Tuple @@ -23,3 +61,6 @@ # Unifier UnionFind = _make.UnionFind TypeUnifier = _make.TypeUnifier + +# Utility Functionality @TODO(jroesch): move to another location +_type_alpha_eq = _make._type_alpha_eq diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index 2790b546cfe5..a04089792282 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -4,10 +4,24 @@ from enum import IntEnum from .base import Span, NodeBase, register_relay_node from tvm import expr +# TODO(@jroesch): move me +from ._make import _type_alpha_eq class Type(NodeBase): """The base type for all Relay types.""" - pass + + def __eq__(self, other) -> bool: + """Compares two Relay types for structural equivalence using + alpha equivalence. + """ + return bool(_type_alpha_eq(self, other)) + + def __ne__(self, other) -> bool: + return not self.__eq__(other) + + def same_as(self, other) -> bool: + """Compares two Relay types by referential equality.""" + return super().__eq__(other) @register_relay_node class TensorType(Type): diff --git a/src/relay/compiler/alpha_eq.cc b/src/relay/compiler/alpha_eq.cc index 4b8e904bf29e..688a93ae73fc 100644 --- a/src/relay/compiler/alpha_eq.cc +++ b/src/relay/compiler/alpha_eq.cc @@ -33,14 +33,14 @@ struct TypeAlphaEq : TypeVisitor { } } -// void VisitType_(const TypeVarNode *bt1, const Type &t2) override { -// if (const TypeVarNode *bt2 = t2.as()) { -// equal = equal && bt1 == bt2; -// return; -// } else { -// equal = false; -// } -// } + void VisitType_(const IncompleteTypeNode *bt1, const Type &t2) override { + if (const IncompleteTypeNode *bt2 = t2.as()) { + equal = equal && bt1 == bt2; + return; + } else { + equal = false; + } + } void VisitType_(const TypeParamNode *ti1, const Type &t2) override { if (const TypeParamNode *ti2 = t2.as()) { diff --git a/src/relay/compiler/type_visitor.h b/src/relay/compiler/type_visitor.h index 5ae100a8de6d..60ae810a6b96 100644 --- a/src/relay/compiler/type_visitor.h +++ b/src/relay/compiler/type_visitor.h @@ -18,35 +18,34 @@ namespace relay { * We recursively visit each type contained inside the visitor. */ template -struct TypeVisitor : TypeFunctor { - // void VisitType_(const TypeVarNode* op, Args... args) override {} +struct TypeVisitor : ::tvm::relay::TypeFunctor { void VisitType_(const TypeParamNode* op, Args... args) override {} void VisitType_(const FuncTypeNode* op, Args... args) override { - // this->VisitType(op->id, args...); + // fix me handle poly + // this->VisitType(op->var, args...); // this->VisitType(op->boundType, args...); - // for (auto arg_type : op->arg_types) { - // this->VisitType(arg_type, args...); - // } - // this->VisitType(op->ret_type, args...); + for (auto arg_type : op->arg_types) { + this->VisitType(arg_type, args...); + } + this->VisitType(op->ret_type, args...); } - void VisitType_(const TensorTypeNode* op, Args... args) override { - // this->VisitType(op->dtype, args...); - // this->VisitType(op->shape, args...); - } + void VisitType_(const TensorTypeNode* op, Args... args) override {} -// void VisitType_(const TupleTypeNode* op, Args... args) override { -// for (const Type& t : op->fields) { -// this->VisitType(t, args...); -// } -// } + // void VisitType_(const TupleTypeNode* op, Args... args) override { + // for (const Type& t : op->fields) { + // this->VisitType(t, args...); + // } + // } -// void VisitType_(const TypeCallNode* op, Args... args) override { -// for (const Type& t : op->args) { -// this->VisitType(t, args...); -// } -// } + void VisitType_(const TypeCallNode* op, Args... args) override { + this->VisitType(op->func, args...); + + for (const Type& t : op->args) { + this->VisitType(t, args...); + } + } void VisitType_(const TypeFunctionNode* op, Args... args) override {} void VisitType_(const IncompleteTypeNode* op, Args... args) override {} @@ -60,48 +59,46 @@ struct TypeFVisitor : TypeFunctor { } Type VisitType_(const TypeParamNode* op) override { - return GetRef(op); + return GetRef(op); } -// Type VisitType_(const TypeArrowNode* op) override { -// std::vector args; -// for (auto arg_type : op->arg_types) { -// args.push_back(VisitType(arg_type)); -// } -// return TypeArrowNode::make(tvm::Array(args), VisitType(op->ret_type)); -// } - -// Type VisitType_(const TypeQuantifierNode* op) override { -// auto new_id = this->VisitType(op->id); -// if (const TypeParamNode* tin = new_id.as()) { -// return TypeQuantifierNode::make(GetRef(tin), -// this->VisitType(op->boundType)); -// } else { -// throw dmlc::Error("Cannot quantify something that is not a type ID"); -// } -// } - -// Type VisitType_(const TupleTypeNode* op) override { -// std::vector new_fields; -// for (const Type& t : op->fields) { -// new_fields.push_back(this->VisitType(t)); -// } -// return TupleTypeNode::make(new_fields); -// } - -// Type VisitType_(const TypeCallNode* op) override { -// auto func = this->VisitType(op->func); -// std::vector new_args; -// for (const Type& t : op->args) { -// new_args.push_back(this->VisitType(t)); -// } -// return TypeCallNode::make(func, new_args); -// } + Type VisitType_(const FuncTypeNode* op) override { + // auto new_id = this->VisitType(op->var); + // if (const TypeParamNode* tin = new_id.as()) { + // return TypeQuantifierNode::make(GetRef(tin), + // this->VisitType(op->boundType)); + + std::vector args; + for (auto arg_type : op->arg_types) { + args.push_back(VisitType(arg_type)); + } + + return FuncTypeNode::make(tvm::Array(args), + VisitType(op->ret_type), {}, {}); // fix me + } + + // Type VisitType_(const TupleTypeNode* op) override { + // std::vector new_fields; + // for (const Type& t : op->fields) { + // new_fields.push_back(this->VisitType(t)); + // } + // return TupleTypeNode::make(new_fields); + // } + + Type VisitType_(const TypeCallNode* op) override { + auto func = this->VisitType(op->func); + std::vector new_args; + for (const Type& t : op->args) { + new_args.push_back(this->VisitType(t)); + } + return TypeCallNode::make(func, new_args); + } + Type VisitType_(const IncompleteTypeNode* op) override { - return GetRef(op); + return GetRef(op); } -}; + }; } // namespace relay -} // namespace tvm +} // namespace relay #endif // TVM_RELAY_TYPE_VISITOR_H_ diff --git a/src/relay/compiler/typechecker.cc b/src/relay/compiler/typechecker.cc index c1f7b7f88765..e16481b7f9e0 100644 --- a/src/relay/compiler/typechecker.cc +++ b/src/relay/compiler/typechecker.cc @@ -764,7 +764,7 @@ TVM_REGISTER_API("relay._make.IncompleteType") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const IncompleteTypeNode *node, tvm::IRPrinter *p) { - p->stream << "IncompleteTypeNode(" << node->kind << ", " << &node << ")"; + p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; }); } // namespace relay diff --git a/src/relay/compiler/unifier.cc b/src/relay/compiler/unifier.cc index bfd3e1a5ff32..ff46e8e863d1 100644 --- a/src/relay/compiler/unifier.cc +++ b/src/relay/compiler/unifier.cc @@ -9,8 +9,8 @@ #include "tvm/relay/compiler/alpha_eq.h" #include "./unifier.h" #include "./type_visitor.h" +#include "./type_subst.h" // #include "tvm/relay/typeck/kindchecker.h" -// #include "tvm/relay/typeck/type_subst.h" namespace tvm { namespace relay { @@ -60,8 +60,6 @@ void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { if (const IncompleteTypeNode *pvn1 = parent1.as()) { auto pv1 = GetRef(pvn1); this->uf_map.Set(pv1, parent2); - // path compression: can also set v1 directly - this->uf_map.Set(v1, parent2); return; } @@ -69,8 +67,6 @@ void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { if (const IncompleteTypeNode *pvn2 = parent2.as()) { auto pv2 = GetRef(pvn2); this->uf_map.Set(pv2, parent1); - // path compression: can also set v2 directly - this->uf_map.Set(v2, parent1); return; } @@ -84,8 +80,6 @@ void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { if (const IncompleteTypeNode *pvn1 = parent1.as()) { auto pv1 = GetRef(pvn1); this->uf_map.Set(pv1, t); - // path compression: can also set v1 directly - this->uf_map.Set(v1, t); return; } @@ -181,6 +175,24 @@ Type TypeUnifierNode::subst(const Type &t) { return ret; } +Type TypeUnifierNode::VisitType(const Type & t1, const Type t2) { + // When the right hand size is a type variable immediately unify. + if (const IncompleteTypeNode *tvn2 = t2.as()) { + return this->unifyWithIncompleteType(t1, GetRef(tvn2)); + } else { + return TypeFunctor::VisitType(t1, t2); + } +} + +Type TypeUnifierNode::unifyWithIncompleteType(const Type &t1, const IncompleteType tv2) { + RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 << std::endl; + // Fix unify to return new representative + this->uf->unify(tv2, t1); + auto rep = this->uf->find(tv2); + RELAY_LOG(INFO) << "unifyWithIncompleteType: rep =" << rep << std::endl; + return rep; +} + Type TypeUnifierNode::VisitType_(const IncompleteTypeNode *t1, const Type rt2) { IncompleteType tv1 = GetRef(t1); RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode t1=" << t1 << " = " << rt2 @@ -194,11 +206,6 @@ Type TypeUnifierNode::VisitType_(const IncompleteTypeNode *t1, const Type rt2) { Type TypeUnifierNode::VisitType_(const TypeParamNode *t1, const Type rt2) { TypeParam ti1 = GetRef(t1); - // for typevars, remap and attempt to unify if already defined - if (const IncompleteTypeNode *tvn2 = rt2.as()) { - return this->unifyWithIncompleteType(ti1, GetRef(tvn2)); - } - // for other type ids, only check equality if (const TypeParamNode *tin2 = rt2.as()) { TypeParam ti2 = GetRef(tin2); @@ -215,75 +222,55 @@ Type TypeUnifierNode::VisitType_(const TypeParamNode *t1, const Type rt2) { } Type TypeUnifierNode::VisitType_(const FuncTypeNode *t1, const Type rt2) { - return rt2; -// TypeArrow ta1 = GetRef(t1); - -// // for typevar, remap if necessary -// if (const IncompleteTypeNode *tvn2 = rt2.as()) { -// return this->unifyWithIncompleteType(ta1, GetRef(tvn2)); -// } + FuncType ft1 = GetRef(t1); -// // for other arrow, unify arg and ret types -// if (const TypeArrowNode *tan2 = rt2.as()) { -// TypeArrow ta2 = GetRef(tan2); + if (const FuncTypeNode *tan2 = rt2.as()) { + FuncType ft2 = GetRef(tan2); -// if (ta1->arg_types.size() != ta2->arg_types.size()) { -// throw UnificationError("unable to unify functions of different arities"); -// } + if (ft1->type_params.size() != ft2->type_params.size()) { + throw UnificationError("unable to unify functions with differing number of type parameters"); + } -// tvm::Array unified_args; -// for (size_t i = 0; i < ta1->arg_types.size(); i++) { -// unified_args.push_back( -// this->VisitType(ta1->arg_types[i], ta2->arg_types[i])); -// } + if (ft1->type_params.size() != 0) { + throw dmlc::Error("NYI"); + } -// Type unified_ret_type = this->VisitType(ta1->ret_type, ta2->ret_type); -// return TypeArrowNode::make(unified_args, unified_ret_type); -// } + // TypeParam id1 = tq1->id; + // TypeParam id2 = tq2->id; -// throw UnificationError("Unable to unify TypeArrowNode"); -// } + // if (id1->kind != id2->kind) { + // throw UnificationError( + // "Cannot unify quantifiers over ids of different kinds"); + // } -// Type TypeUnifierNode::VisitType_(const TypeQuantifierNode *t1, const Type rt2) { -// TypeQuantifier tq1 = GetRef(t1); + // TypeParam fresh = TypeParamNode::make(id1->name, id1->kind); -// // for typevars, remap and attempt to unify if already defined -// if (const IncompleteTypeNode *tvn2 = rt2.as()) { -// return this->unifyWithIncompleteType(tq1, GetRef(tvn2)); -// } + // auto bt1 = type_subst(tq1->boundType, id1, fresh); + // auto bt2 = type_subst(tq2->boundType, id2, fresh); -// // for other quantifiers, attempt to unify bound types after normalizing -// if (const TypeQuantifierNode *tqn2 = rt2.as()) { -// TypeQuantifier tq2 = GetRef(tqn2); -// TypeParam id1 = tq1->id; -// TypeParam id2 = tq2->id; + // Type unified_bound_type = this->VisitType(bt1, bt2); -// if (id1->kind != id2->kind) { -// throw UnificationError( -// "Cannot unify quantifiers over ids of different kinds"); -// } + if (ft1->arg_types.size() != ft2->arg_types.size()) { + throw UnificationError("unable to unify functions of different arities"); + } -// TypeParam fresh = TypeParamNode::make(id1->name, id1->kind); + tvm::Array unified_args; + for (size_t i = 0; i < ft1->arg_types.size(); i++) { + unified_args.push_back( + this->VisitType(ft1->arg_types[i], ft2->arg_types[i])); + } -// auto bt1 = type_subst(tq1->boundType, id1, fresh); -// auto bt2 = type_subst(tq2->boundType, id2, fresh); + Type unified_ret_type = this->VisitType(ft1->ret_type, ft2->ret_type); -// Type unified_bound_type = this->VisitType(bt1, bt2); -// return TypeQuantifierNode::make(fresh, unified_bound_type); -// } + return FuncTypeNode::make(unified_args, unified_ret_type, {}, {}); + } -// // anything else cannot be unified -// throw UnificationError("Cannot unify TypeQuantifierNode"); + throw UnificationError("unable to unify function types"); } Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { TensorType tt1 = GetRef(t1); - // for typevars, remap and attempt to unify if already defined - if (const IncompleteTypeNode *tvn2 = rt2.as()) { - return this->unifyWithIncompleteType(tt1, GetRef(tvn2)); - } - if (const TensorTypeNode *ttn2 = rt2.as()) { TensorType tt2 = GetRef(ttn2); @@ -360,10 +347,6 @@ Type TypeUnifierNode::VisitType_(const TypeFunctionNode *sen1, const Type t2) { Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { TypeCall ty_call1 = GetRef(tcn1); - if (const IncompleteTypeNode *tvn2 = t2.as()) { - return this->unifyWithIncompleteType(ty_call1, GetRef(tvn2)); - } - if (const TypeCallNode *tcn2 = t2.as()) { Type unified_func = this->VisitType(ty_call1->func, tcn2->func); @@ -385,14 +368,7 @@ Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { } } -Type TypeUnifierNode::unifyWithIncompleteType(const Type &t1, const IncompleteType tv2) { - RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 << std::endl; - // Fix unify to return new representative - this->uf->unify(tv2, t1); - auto rep = this->uf->find(tv2); - RELAY_LOG(INFO) << "unifyWithIncompleteType: rep =" << rep << std::endl; - return rep; -} + TVM_REGISTER_API("relay._make.TypeUnifier") .set_body([](TVMArgs args, TVMRetValue *ret) { diff --git a/src/relay/compiler/unifier.h b/src/relay/compiler/unifier.h index 6788265c90f2..cba96ff02451 100644 --- a/src/relay/compiler/unifier.h +++ b/src/relay/compiler/unifier.h @@ -101,7 +101,7 @@ class TypeUnifierNode : public Node, private: // unify non-typevar with typevar Type unifyWithIncompleteType(const Type& t1, const IncompleteType tvn2); - + Type VisitType(const Type & t1, const Type t2) override; Type VisitType_(const IncompleteTypeNode* t1, const Type t2) override; Type VisitType_(const TensorTypeNode* t1, const Type t2) override; Type VisitType_(const TypeParamNode* t1, const Type t2) override; diff --git a/src/relay/type.cc b/src/relay/type.cc index 22d37ea05fda..2b6647a5807e 100644 --- a/src/relay/type.cc +++ b/src/relay/type.cc @@ -6,6 +6,7 @@ #include "tvm/relay/type.h" #include "tvm/ir_functor.h" + namespace tvm { namespace relay { @@ -19,12 +20,49 @@ TensorType TensorTypeNode::make(Array shape, DataType dtype) { return TensorType(n); } +TensorType TensorTypeNode::Int(int bits, int lanes) { + return TensorTypeNode::make({}, HalideIR::Int(bits, lanes)); +} + +TensorType TensorTypeNode::UInt(int bits, int lanes) { + return TensorTypeNode::make({}, HalideIR::UInt(bits, lanes)); +} + +TensorType TensorTypeNode::Float(int bits, int lanes) { + return TensorTypeNode::make({}, HalideIR::Float(bits, lanes)); +} + +TensorType TensorTypeNode::Bool(int lanes) { + return TensorTypeNode::make({}, HalideIR::Bool(lanes)); +} + TVM_REGISTER_API("relay._make.TensorType") .set_body([](TVMArgs args, TVMRetValue *ret) { Array shape = args[0]; *ret = TensorTypeNode::make(shape, args[1]); }); + +TVM_REGISTER_API("relay._make.IntType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TensorTypeNode::Int(args[0], args[1]); + }); + +TVM_REGISTER_API("relay._make.UIntType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TensorTypeNode::UInt(args[0], args[1]); + }); + +TVM_REGISTER_API("relay._make.BoolType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TensorTypeNode::Bool(args[0]); + }); + +TVM_REGISTER_API("relay._make.FloatType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TensorTypeNode::Float(args[0], args[1]); + }); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TensorTypeNode *node, tvm::IRPrinter *p) { diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py index 065a91f0abbe..21889faa51ee 100644 --- a/tests/python/relay/test_unifier.py +++ b/tests/python/relay/test_unifier.py @@ -2,12 +2,16 @@ Test the type unifier, which solves systems of equations between incomplete types. """ +import tvm from tvm.relay import ir from tvm.relay.unifier import UnionFind, TypeUnifier +from tvm.relay.ir_builder import bool_type, uint_type, int_type, float_type, func_type +from tvm.relay import ir_builder as build import tvm.relay.make as mk + def test_insert_and_find(): - uf = mk.UnionFind()() + uf = mk.UnionFind() v1 = mk.IncompleteType(ir.Kind.Type) v2 = mk.IncompleteType(ir.Kind.Type) uf.insert(v1) @@ -15,8 +19,9 @@ def test_insert_and_find(): assert uf.find(v1) == v1 assert uf.find(v2) == v2 + def test_insert_error(): - uf = mk.UnionFind()() + uf = mk.UnionFind() v1 = mk.IncompleteType(ir.Kind.Type) v2 = mk.IncompleteType(ir.Kind.Type) uf.insert(v1) @@ -26,8 +31,9 @@ def test_insert_error(): except: return + def test_unify(): - uf = mk.UnionFind()() + uf = mk.UnionFind() v1 = mk.IncompleteType(ir.Kind.Type) v2 = mk.IncompleteType(ir.Kind.Type) v3 = mk.IncompleteType(ir.Kind.Type) @@ -48,8 +54,9 @@ def test_unify(): assert uf.find(v2) == new_rep assert uf.find(v3) == new_rep + def test_unify_multiple_levels(): - uf = mk.UnionFind()() + uf = mk.UnionFind() v = [mk.IncompleteType(ir.Kind.Type) for _ in range(9)] for var in v: uf.insert(var) @@ -84,81 +91,92 @@ def test_unify_multiple_levels(): # We have checked that the basic machinery in the UnionFind works # and now we will test the type unifier which will fill in holes # between type equalities by the process of unification. + + def unify_types(t1, t2): unifier = mk.TypeUnifier() return unifier.unify(t1, t2) # TODO(sslyu, weberlo, joshpoll): put in isinstance asserts once those work + + def test_unify_int(): - intty = IntType(1) + intty = int_type(1) unified = unify_types(intty, intty) assert intty == unified + def test_unify_bool(): - boolty = BoolType() + boolty = bool_type() unified = unify_types(boolty, boolty) assert boolty == unified + def test_unify_float(): - floatty = FloatType(4) + floatty = float_type(4) unified = unify_types(floatty, floatty) assert floatty == unified + def test_unify_incompatible_basetypes(): - bt = BoolType() - intty = IntType(32) + bt = bool_type() + intty = int_type(32) try: unify_types(bt, intty) assert False except: return -def test_unify_concrete_type_arrow(): - arr1 = TypeArrow([int_type()], int_type()) - arr2 = TypeArrow([int_type()], int_type()) + +def test_unify_concrete_func_type(): + arr1 = func_type([int_type()], int_type()) + arr2 = func_type([int_type()], int_type()) unified = unify_types(arr1, arr2) assert unified == arr1 -def test_unify_type_arrow_with_holes(): + +def test_unify_func_type_with_holes(): unifier = mk.TypeUnifier() v1 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.unify(v1, bool_type()) - arr1 = TypeArrow([int_type()], bool_type()) - arr2 = TypeArrow([int_type()], v1) + arr1 = func_type([int_type()], bool_type()) + arr2 = func_type([int_type()], v1) unified = unifier.unify(arr1, arr2) assert unified == arr1 v2 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v2) unifier.unify(v2, int_type()) - arr3 = TypeArrow([v2], bool_type()) + arr3 = func_type([v2], bool_type()) unified = unifier.unify(arr1, arr3) assert unified == arr1 -def test_reject_incompatible_type_arrows(): - arr1 = TypeArrow([int_type()], bool_type()) - arr2 = TypeArrow([int_type(), bool_type()], bool_type()) + +def test_reject_incompatible_func_types(): + arr1 = func_type([int_type()], bool_type()) + arr2 = func_type([int_type(), bool_type()], bool_type()) try: unify_types(arr1, arr2) assert False except: return -def test_unify_concrete_type_quantifiers(): - tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) - tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), int_type()) - unified = unify_types(tq1, tq2) - assert unified == tq1 +# def test_unify_concrete_type_quantifiers(): +# tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) +# tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), int_type()) +# unified = unify_types(tq1, tq2) +# assert unified == tq1 + +# def test_unify_basetype_with_quantifier_error(): +# bt = bool_type() +# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) +# try: +# unify_types(bt, tq) +# assert False +# except: +# return -def test_unify_basetype_with_quantifier_error(): - bt = bool_type() - tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) - try: - unify_types(bt, tq) - assert False - except: - return def test_unify_typevars_with_each_other(): unifier = mk.TypeUnifier() @@ -174,11 +192,12 @@ def test_unify_typevars_with_each_other(): new_unified = unifier.unify(v1, v3) assert (new_unified == v1 or new_unified == v2 or new_unified == v3) + def test_unify_typevars_with_basetype(): unifier = mk.TypeUnifier() - bt = BoolType() - v1 = mk.IncompleteType(ir.Kind.BaseType) - v2 = mk.IncompleteType(ir.Kind.BaseType) + bt = bool_type() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unified1 = unifier.unify(v1, bt) @@ -186,11 +205,12 @@ def test_unify_typevars_with_basetype(): unified2 = unifier.unify(v1, v2) assert unified2 == bt + def test_unify_compatible_typevars(): unifier = mk.TypeUnifier() - bt = BoolType() - v1 = mk.IncompleteType(ir.Kind.BaseType) - v2 = mk.IncompleteType(ir.Kind.BaseType) + bt = bool_type() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unifier.unify(v1, bt) @@ -200,162 +220,154 @@ def test_unify_compatible_typevars(): unified = unifier.unify(v1, v2) assert unified == bt -def test_unify_incompatible_typevars(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) - v2 = mk.IncompleteType(ir.Kind.BaseType) - bt = bool_type() - tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) - unifier.insert(v1) - unifier.insert(v2) - unifier.unify(v1, bt) - unifier.unify(v2, tq) - # bt cannot be unified with tq, so unifying v1 and v2 should give an error - try: - unifier.unify(v1, v2) - assert False - except: - return +# def test_unify_incompatible_typevars(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.Type) +# v2 = mk.IncompleteType(ir.Kind.Type) +# bt = bool_type() +# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) +# unifier.insert(v1) +# unifier.insert(v2) +# unifier.unify(v1, bt) +# unifier.unify(v2, tq) +# # bt cannot be unified with tq, so unifying v1 and v2 should give an error +# try: +# unifier.unify(v1, v2) +# assert False +# except: +# return + +# def test_unify_typevar_with_quantifier(): +# unifier = mk.TypeUnifier() +# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bool_type()) +# v1 = mk.IncompleteType(ir.Kind.BaseType) +# unifier.insert(v1) +# unified = unifier.unify(v1, tq) +# assert unified == tq + +# def test_unify_typevars_inside_concrete_quantifier(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.BaseType) +# unifier.insert(v1) +# tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v1) +# tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), bool_type()) +# unified = unifier.unify(tq1, tq2) +# assert unified == tq2 -def test_unify_typevar_with_quantifier(): - unifier = mk.TypeUnifier() - tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bool_type()) - v1 = mk.IncompleteType(ir.Kind.BaseType) - unifier.insert(v1) - unified = unifier.unify(v1, tq) - assert unified == tq - -def test_unify_typevars_inside_concrete_quantifier(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) - unifier.insert(v1) - tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v1) - tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), bool_type()) - unified = unifier.unify(tq1, tq2) - assert unified == tq2 def test_unify_concrete_tensors(): - bt = BoolType() - shape = make_shape([1, 2, 3]) - tt1 = TensorType(bt, shape) - tt2 = TensorType(bt, shape) + bt = build.bool_dtype() + shape = tvm.convert([1, 2, 3]) + tt1 = mk.TensorType(shape, bt) + tt2 = mk.TensorType(shape, bt) unified = unify_types(tt1, tt2) assert unified == tt1 + def test_unify_tensor_shape_reject(): - bt = BoolType() - shape1 = make_shape([1, 2, 3]) - shape2 = make_shape([2, 3, 4]) - tt1 = TensorType(bt, shape1) - tt2 = TensorType(bt, shape2) + bt = build.bool_dtype() + shape1 = tvm.convert([1, 2, 3]) + shape2 = tvm.convert([2, 3, 4]) + tt1 = mk.TensorType(shape1, bt) + tt2 = mk.TensorType(shape2, bt) try: unify_types(tt1, tt2) assert False except: return + def test_unify_tensor_dtype_reject(): - bt1 = BoolType() - bt2 = IntType(32) - shape = make_shape([1, 2, 3]) - tt1 = TensorType(bt1, shape) - tt2 = TensorType(bt2, shape) + bt1 = build.bool_dtype() + bt2 = build.int_dtype() + shape = tvm.convert([1, 2, 3]) + tt1 = mk.TensorType(shape, bt1) + tt2 = mk.TensorType(shape, bt2) try: unify_types(tt1, tt2) assert False except: return -def test_unify_quantified_tensors(): - x = TypeParam("x", ir.type.Kind.Shape) - y = TypeParam("y", ir.type.Kind.Shape) - tq1 = TypeQuantifier(x, TensorType(BoolType(), x)) - tq2 = TypeQuantifier(y, TensorType(BoolType(), y)) - unified = unify_types(tq1, tq2) - assert unified == tq1 - - a = TypeParam("a", ir.type.Kind.BaseType) - b = TypeParam("b", ir.type.Kind.BaseType) - tq3 = TypeQuantifier(a, TensorType(a, make_shape([1, 2, 3]))) - tq4 = TypeQuantifier(b, TensorType(b, make_shape([1, 2, 3]))) - unified = unify_types(tq3, tq4) - assert unified == tq3 - -def test_unify_concrete_products(): - bt = bool_type() - intty = int_type() - pt1 = TupleType([bt, intty]) - pt2 = TupleType([bt, intty]) - unified = unify_types(pt1, pt2) - assert unified == pt1 - -def test_unify_products_reject_size(): - bt = BoolType() - intty = IntType(32) - pt1 = TupleType([bt, bt, intty]) - pt2 = TupleType([bt, intty]) - try: - unify_types(pt1, pt2) - assert False - except: - return - -def test_unify_products_reject_member(): - bt = BoolType() - intty = IntType(32) - pt1 = TupleType([bt, bt]) - pt2 = TupleType([bt, intty]) - try: - unify_types(pt1, pt2) - assert False - except: - return +# def test_unify_quantified_tensors(): +# x = TypeParam("x", ir.type.Kind.Shape) +# y = TypeParam("y", ir.type.Kind.Shape) +# tq1 = TypeQuantifier(x, mk.TensorType(bool_type(), x)) +# tq2 = TypeQuantifier(y, mk.TensorType(bool_type(), y)) +# unified = unify_types(tq1, tq2) +# assert unified == tq1 + +# a = TypeParam("a", ir.type.Kind.BaseType) +# b = TypeParam("b", ir.type.Kind.BaseType) +# tq3 = TypeQuantifier(a, mk.TensorType(a, make_shape([1, 2, 3]))) +# tq4 = TypeQuantifier(b, mk.TensorType(b, make_shape([1, 2, 3]))) +# unified = unify_types(tq3, tq4) +# assert unified == tq3 + +# def test_unify_concrete_products(): +# bt = bool_type() +# intty = int_type() +# pt1 = TupleType([bt, intty]) +# pt2 = TupleType([bt, intty]) +# unified = unify_types(pt1, pt2) +# assert unified == pt1 + +# def test_unify_products_reject_size(): +# bt = bool_type() +# intty = IntType(32) +# pt1 = TupleType([bt, bt, intty]) +# pt2 = TupleType([bt, intty]) +# try: +# unify_types(pt1, pt2) +# assert False +# except: +# return + +# def test_unify_products_reject_member(): +# bt = bool_type() +# intty = int_type() +# pt1 = TupleType([bt, bt]) +# pt2 = TupleType([bt, intty]) +# try: +# unify_types(pt1, pt2) +# assert False +# except: +# return + +# def test_unify_products_typevar(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.BaseType) +# bt = bool_type() +# pt1 = TupleType([bt, bt]) +# pt2 = TupleType([v1, bt]) +# unifier.insert(v1) +# unified = unifier.unify(pt1, pt2) +# assert unified == pt1 + +# def test_unify_quantified_products(): +# x = TypeParam("x", ir.Kind.Type) +# y = TypeParam("y", ir.Kind.Type) +# p1 = TypeQuantifier(x, TupleType([int_type(), x])) +# p2 = TypeQuantifier(y, TupleType([int_type(), y])) +# unified = unify_types(p1, p2) +# assert unified == p1 -def test_unify_products_typevar(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) - bt = bool_type() - pt1 = TupleType([bt, bt]) - pt2 = TupleType([v1, bt]) - unifier.insert(v1) - unified = unifier.unify(pt1, pt2) - assert unified == pt1 - -def test_unify_quantified_products(): - x = TypeParam("x", ir.Kind.Type) - y = TypeParam("y", ir.Kind.Type) - p1 = TypeQuantifier(x, TupleType([int_type(), x])) - p2 = TypeQuantifier(y, TupleType([int_type(), y])) - unified = unify_types(p1, p2) - assert unified == p1 - -def test_unify_ref_types(): - r1 = RefType(bool_type()) - r2 = RefType(bool_type()) - assert unify_types(r1, r2) == r1 - -def test_unify_ref_reject_inner(): - r1 = RefType(BoolType()) - r2 = RefType(IntType(32)) - try: - unify_types(r1, r2) - assert False - except: - return def test_subst_basetype(): unifier = mk.TypeUnifier() - bt = BoolType() + bt = bool_type() assert bt == unifier.subst(bt) + def test_subst_simple_hole(): unifier = mk.TypeUnifier() v1 = mk.IncompleteType(ir.Kind.BaseType) - bt = BoolType() + bt = bool_type() unifier.insert(v1) unifier.unify(v1, bt) assert unifier.subst(v1) == bt + def test_subst_typevar_for_typevar(): unifier = mk.TypeUnifier() v1 = mk.IncompleteType(ir.Kind.Type) @@ -364,13 +376,26 @@ def test_subst_typevar_for_typevar(): unifier.insert(v2) unifier.unify(v1, v2) - assert unifier.subst(v1) == v2 + assert unifier.subst(v1) == unifier.subst(v2) + + +def test_subst_typevar_for_typevar_comm(): + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) + unifier.insert(v1) + unifier.insert(v2) + + unifier.unify(v2, v1) + assert unifier.subst(v1) == unifier.subst(v2) + def test_subst_concrete_arrow(): unifier = mk.TypeUnifier() - arr1 = TypeArrow([int_type()], int_type()) + arr1 = func_type([int_type()], int_type()) assert unifier.subst(arr1) == arr1 + def test_subst_arrow_with_holes(): unifier = mk.TypeUnifier() v1 = mk.IncompleteType(ir.Kind.BaseType) @@ -379,93 +404,93 @@ def test_subst_arrow_with_holes(): unifier.insert(v2) unifier.unify(v1, int_type()) unifier.unify(v2, bool_type()) - arr1 = TypeArrow([v1], v2) - arr2 = TypeArrow([int_type()], bool_type()) + arr1 = func_type([v1], v2) + arr2 = func_type([int_type()], bool_type()) assert unifier.subst(arr1) == arr2 -def test_subst_concrete_quantifier(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) - tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) - unifier.insert(v1) - unifier.unify(v1, tq) - assert unifier.subst(v1) == tq +# def test_subst_concrete_quantifier(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.BaseType) +# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) +# unifier.insert(v1) +# unifier.unify(v1, tq) +# assert unifier.subst(v1) == tq + +# def test_subst_quantifier_with_holes(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.Type) +# v2 = mk.IncompleteType(ir.Kind.Type) +# tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v2) +# intty = int_type() +# tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), intty) + # unifier.insert(v1) + # unifier.insert(v2) + # unifier.unify(v2, intty) + # unifier.unify(v1, tq1) + # assert unifier.subst(v1) == tq2 -def test_subst_quantifier_with_holes(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) - tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v2) - intty = int_type() - tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), intty) - - unifier.insert(v1) - unifier.insert(v2) - unifier.unify(v2, intty) - unifier.unify(v1, tq1) - assert unifier.subst(v1) == tq2 def test_subst_concrete_tensor(): unifier = mk.TypeUnifier() v1 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) - tt = TensorType(BoolType(), make_shape([1, 2, 3])) + tt = mk.TensorType(tvm.convert([1, 2, 3]), 'uint1') unifier.unify(v1, tt) assert unifier.subst(v1) == tt -def test_subst_concrete_product(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - bt = bool_type() - pt = TupleType([bt, bt]) - unifier.unify(v1, pt) - assert unifier.subst(v1) == pt - -def test_subst_product_with_holes(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) - v3 = mk.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - unifier.insert(v2) - unifier.insert(v3) - - tt1 = TensorType(IntType(32), ShapeSeq([])) - tt2 = TensorType(FloatType(32), ShapeSeq([])) - pt1 = TupleType([tt1, v2, v3]) - unifier.unify(v2, tt2) - unifier.unify(v3, v2) - unifier.unify(v1, pt1) - pt2 = TupleType([tt1, tt2, tt2]) - assert unifier.subst(v1) == pt2 - -def test_subst_concrete_ref(): - unifier = mk.TypeUnifier() - rt = RefType(bool_type()) - assert unifier.subst(rt) == rt - -def test_subst_ref_with_hole(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - - unifier.unify(v1, bool_type()) - rt1 = RefType(v1) - rt2 = RefType(bool_type()) - assert unifier.subst(rt1) == rt2 - -def test_typevar_on_lhs(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) - v2 = mk.IncompleteType(ir.Kind.Type) - bt = bool_type() - tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt, bt) - unifier.insert(v1) - unifier.insert(v2) - unified1 = unifier.unify(bt, v1) - assert unified1 == bt - unified2 = unifier.unify(tq, v2) - assert unified2 == tq - assert unifier.subst(v1) == bt - assert unifier.subst(v2) == tq +# def test_subst_concrete_product(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.Type) +# unifier.insert(v1) +# bt = bool_type() +# pt = TupleType([bt, bt]) +# unifier.unify(v1, pt) +# assert unifier.subst(v1) == pt + +# def test_subst_product_with_holes(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.Type) +# v2 = mk.IncompleteType(ir.Kind.Type) +# v3 = mk.IncompleteType(ir.Kind.Type) +# unifier.insert(v1) +# unifier.insert(v2) +# unifier.insert(v3) + +# tt1 = mk.TensorType(int_type(), tvm.convert([])) +# tt2 = mk.TensorType(FloatType(32), tvm.convert([])) +# pt1 = TupleType([tt1, v2, v3]) +# unifier.unify(v2, tt2) +# unifier.unify(v3, v2) +# unifier.unify(v1, pt1) +# pt2 = TupleType([tt1, tt2, tt2]) +# assert unifier.subst(v1) == pt2 + +# def test_subst_concrete_ref(): +# unifier = mk.TypeUnifier() +# rt = RefType(bool_type()) +# assert unifier.subst(rt) == rt + +# def test_subst_ref_with_hole(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.Type) +# unifier.insert(v1) + +# unifier.unify(v1, bool_type()) +# rt1 = RefType(v1) +# rt2 = RefType(bool_type()) +# assert unifier.subst(rt1) == rt2 + +# def test_typevar_on_lhs(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.BaseType) +# v2 = mk.IncompleteType(ir.Kind.Type) +# bt = bool_type() +# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt, bt) +# unifier.insert(v1) +# unifier.insert(v2) +# unified1 = unifier.unify(bt, v1) +# assert unified1 == bt +# unified2 = unifier.unify(tq, v2) +# assert unified2 == tq +# assert unifier.subst(v1) == bt +# assert unifier.subst(v2) == tq From 44b32166f3dea90efdc8d9aacfa115a16901479f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 01:01:31 -0700 Subject: [PATCH 24/88] Start refactoring type checker Introduce both Environment and type inference Python interfaces for testing. --- .../compiler/{typechecker.h => type_infer.h} | 10 +- python/tvm/relay/_env.py | 5 + python/tvm/relay/_env.pyi | 18 ++++ python/tvm/relay/_type_infer.py | 5 + python/tvm/relay/_type_infer.pyi | 6 ++ python/tvm/relay/env.py | 98 +++++++++++++++++++ python/tvm/relay/ir_builder.py | 4 +- python/tvm/relay/type_infer.py | 6 ++ .../{typechecker.cc => type_infer.cc} | 28 +++--- tests/python/relay/test_typechecker.py | 17 ++++ 10 files changed, 177 insertions(+), 20 deletions(-) rename include/tvm/relay/compiler/{typechecker.h => type_infer.h} (70%) create mode 100644 python/tvm/relay/_env.py create mode 100644 python/tvm/relay/_env.pyi create mode 100644 python/tvm/relay/_type_infer.py create mode 100644 python/tvm/relay/_type_infer.pyi create mode 100644 python/tvm/relay/env.py create mode 100644 python/tvm/relay/type_infer.py rename src/relay/compiler/{typechecker.cc => type_infer.cc} (98%) create mode 100644 tests/python/relay/test_typechecker.py diff --git a/include/tvm/relay/compiler/typechecker.h b/include/tvm/relay/compiler/type_infer.h similarity index 70% rename from include/tvm/relay/compiler/typechecker.h rename to include/tvm/relay/compiler/type_infer.h index c69aba3c1e71..4c16defe977f 100644 --- a/include/tvm/relay/compiler/typechecker.h +++ b/include/tvm/relay/compiler/type_infer.h @@ -1,8 +1,10 @@ /*! - * Copyright (c) 2017 by Contributors - * \file tvm/relay/typechecker.h - * \brief Type check a Relay program producing a type checked program - * with its checked_type field populated and incomplete types resolved. + * Copyright (c) 2018 by Contributors + * \file tvm/relay/type_infer.h + * \brief Perform type inference and checking on Relay programs. + * + * The pass produces a new expression with its checked_type + * field populated and incomplete types resolved. */ #ifndef TVM_RELAY_COMPILER_TYPECHECKER_H_ #define TVM_RELAY_COMPILER_TYPECHECKER_H_ diff --git a/python/tvm/relay/_env.py b/python/tvm/relay/_env.py new file mode 100644 index 000000000000..25b8715a7816 --- /dev/null +++ b/python/tvm/relay/_env.py @@ -0,0 +1,5 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable +"""The interface to the Environment exposed from C++.""" +from tvm._ffi.function import _init_api + +_init_api("relay._env", __name__) diff --git a/python/tvm/relay/_env.pyi b/python/tvm/relay/_env.pyi new file mode 100644 index 000000000000..d14e726e5443 --- /dev/null +++ b/python/tvm/relay/_env.pyi @@ -0,0 +1,18 @@ +from typing import Union, Tuple, Dict, List +from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId +from relay.ir import ShapeExtension, Operator, Defn + +class Environment(NodeBase): ... + +def Environment_add(self: Environment, func: GlobalId) -> None: ... +def Environment_global_id(self: Environment, name: str) -> GlobalId: ... +def Environment_operator_id(self: Environment, name: str) -> OperatorId: ... +def Environment_lookup_global(self: Environment, id: GlobalId) -> Item: ... +def Environment_lookup_operator(self: Environment, id: OperatorId) -> Item: ... +def Environment_remove_global(self: Environment, id: GlobalId) -> Item: ... +def Environment_add_source(self: Environment, file_name: str, source: str) -> FileId: ... +def Environment_report_error(self: Environment, message: str, span: Span) -> None: ... +def Environment_display_errors(self: Environment) -> None: ... +def Environment_register_shape_ext(self: Environment, shape_ext: ShapeExtension) -> None: ... +def Environment_get_operators(self: Environment) -> List[Operator]: ... +def Environment_get_defns(self: Environment) -> List[Defn]: ... diff --git a/python/tvm/relay/_type_infer.py b/python/tvm/relay/_type_infer.py new file mode 100644 index 000000000000..7213769a4164 --- /dev/null +++ b/python/tvm/relay/_type_infer.py @@ -0,0 +1,5 @@ +"""FFI exposing the Relay type inference and checking.""" + +from tvm._ffi.function import _init_api + +_init_api("relay._type_infer", __name__) diff --git a/python/tvm/relay/_type_infer.pyi b/python/tvm/relay/_type_infer.pyi new file mode 100644 index 000000000000..1bb42ab854c2 --- /dev/null +++ b/python/tvm/relay/_type_infer.pyi @@ -0,0 +1,6 @@ +from .env import Environment +from . import ir + +def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ... +def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ... +def _get_checked_type(expr: ir.Expr) -> ir.Type: ... diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py new file mode 100644 index 000000000000..9bd63476f1fb --- /dev/null +++ b/python/tvm/relay/env.py @@ -0,0 +1,98 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import +"""A global environment storing everything needed to interpret or compile a Realy program.""" +from typing import Union, List +from relay.ir import register_relay_node, NodeBase +from relay.ir import GlobalId, OperatorId, Item, FileId, Span, ShapeExtension +from relay.ir import Operator, Defn +from relay._env import * +import tvm + +# Move me to C++ if possible. +__tgt_host__ = __tgt__ = "llvm" +__relay_tvm_context__ = tvm.cpu() + +ADD_ID = "__add__" +SUB_ID = "__sub__" +MUL_ID = "__mul__" +DIV_ID = "__div__" +NEG_ID = "__neg__" +LT_ID = "__lt__" +LE_ID = "__le__" +GT_ID = "__gt__" +GE_ID = "__ge__" +EQ_ID = "__eq__" +NE_ID = "__ne__" + +@register_relay_node +class Environment(NodeBase): + """The global Relay environment containing definitions, + primitives, options, and more. + """ + def add(self, item: Item) -> None: + return Environment_add(self, item) + + def global_id(self, name: str) -> GlobalId: + return Environment_global_id(self, name) + + def operator_id(self, name: str) -> OperatorId: + return Environment_operator_id(self, name) + + def lookup(self, ident: Union[GlobalId, OperatorId]) -> Item: + if isinstance(ident, OperatorId): + return Environment_lookup_operator(self, ident) + else: + return Environment_lookup_global(self, ident) + + def add_source(self, file_name: str, source: str) -> FileId: + return Environment_add_source(self, file_name, source) + + def report_error(self, message: str, span: Span) -> None: + return Environment_report_error(self, message, span) + + def register_shape_ext(self, ext: ShapeExtension) -> None: + return Environment_register_shape_ext(self, ext) + + def display_errors(self) -> None: + return Environment_display_errors(self) + + def operators(self) -> List[Operator]: + return Environment_get_operators(self) + + def defns(self) -> List[Defn]: + return Environment_get_defns(self) + + def tvm_context(self): + return __relay_tvm_context__ + + def add_id(self) -> OperatorId: + return self.operator_id(ADD_ID) + + def sub_id(self) -> OperatorId: + return self.operator_id(SUB_ID) + + def mul_id(self) -> OperatorId: + return self.operator_id(MUL_ID) + + def div_id(self) -> OperatorId: + return self.operator_id(DIV_ID) + + def neg_id(self) -> OperatorId: + return self.operator_id(NEG_ID) + + def lt_id(self) -> OperatorId: + return self.operator_id(LT_ID) + + def le_id(self) -> OperatorId: + return self.operator_id(LE_ID) + + def gt_id(self) -> OperatorId: + return self.operator_id(GT_ID) + + def ge_id(self) -> OperatorId: + return self.operator_id(GE_ID) + + def eq_id(self) -> OperatorId: + return self.operator_id(EQ_ID) + + def ne_id(self) -> OperatorId: + return self.operator_id(NE_ID) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 8fa9b789f53c..2b2cdb432b43 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -64,11 +64,11 @@ def bind(self, name, type, value): return lv - def let(self, name, value): + def let(self, name, value, value_type=None): if not isinstance(value, expr.Expr): value = into_ast(value) - return self.bind(name, None, value) + return self.bind(name, value_type, value) def function(self, params): def _on_exit(): diff --git a/python/tvm/relay/type_infer.py b/python/tvm/relay/type_infer.py new file mode 100644 index 000000000000..17938dfdcbc4 --- /dev/null +++ b/python/tvm/relay/type_infer.py @@ -0,0 +1,6 @@ +#pylint: disable-all + +from . import _type_infer + +check_expr = _type_infer.check_expr +# generalize = _type_infer.generalize diff --git a/src/relay/compiler/typechecker.cc b/src/relay/compiler/type_infer.cc similarity index 98% rename from src/relay/compiler/typechecker.cc rename to src/relay/compiler/type_infer.cc index e16481b7f9e0..0b7435598d6d 100644 --- a/src/relay/compiler/typechecker.cc +++ b/src/relay/compiler/type_infer.cc @@ -1,10 +1,10 @@ /*! * Copyright (c) 2018 by Contributors - * \file typechecker.cc - * \brief Relay typechecker + * \file type_infer.cc + * \brief Relay type inference and checking. */ -#include "tvm/relay/compiler/typechecker.h" +#include "tvm/relay/compiler/type_infer.h" #include "./incomplete_type.h" // #include "tvm/relay/alpha_eq.h" // #include "tvm/relay/debug.h" @@ -724,12 +724,12 @@ namespace relay { // } // } -// TVM_REGISTER_API("relay._tyck.check_expr") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// Expr e = args[1]; -// *ret = check(env, e); -// }); +TVM_REGISTER_API("relay._type_infer.check_expr") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + Expr e = args[1]; + *ret = check(env, e); + }); // TVM_REGISTER_API("relay._tyck.check_item") // .set_body([](TVMArgs args, TVMRetValue *ret) { @@ -738,11 +738,11 @@ namespace relay { // *ret = check(env, i); // }); -// TVM_REGISTER_API("relay._tyck.get_checked_type") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Expr e = args[0]; -// *ret = e->checked_type(); -// }); +TVM_REGISTER_API("relay._type_infer._get_checked_type") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Expr e = args[0]; + *ret = e->checked_type(); + }); // TVM_REGISTER_API("relay._tyck.generalize") // .set_body([](TVMArgs args, TVMRetValue *ret) { diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py new file mode 100644 index 000000000000..e5466d4439b9 --- /dev/null +++ b/tests/python/relay/test_typechecker.py @@ -0,0 +1,17 @@ +"""Test that type checker correcly computes types + for expressions. +""" +import tvm.relay.make as mk +from tvm.relay.ir_builder import IRBuilder, float_type + +def test_monomorphic_let(): + b = IRBuilder() + # Program: let x = 1; x + x = b.let('x', 1, value_type=float_type()) + b.ret(x) + + prog = b.get() + e = check_expr(prog) + e.get_type() + + From 17a4ab35bee17f67c80544dbb10e1185aa7a819f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 01:13:44 -0700 Subject: [PATCH 25/88] Get a failing test for the type checker --- include/tvm/relay/compiler/type_infer.h | 8 ++++++-- python/tvm/relay/expr.py | 4 +++- python/tvm/relay/make.py | 3 +++ src/relay/compiler/environment.cc | 18 +++++++++--------- src/relay/compiler/type_infer.cc | 13 +++++++------ tests/python/relay/test_typechecker.py | 9 +++++++-- 6 files changed, 35 insertions(+), 20 deletions(-) diff --git a/include/tvm/relay/compiler/type_infer.h b/include/tvm/relay/compiler/type_infer.h index 4c16defe977f..6d07de1c29e8 100644 --- a/include/tvm/relay/compiler/type_infer.h +++ b/include/tvm/relay/compiler/type_infer.h @@ -19,8 +19,12 @@ namespace relay { * with unambigous type information filled in, as well as it's * checked type field populated with the result type. */ -Expr check(const Environment & env, const Expr & e); -Operator check(const Environment & env, const Operator & op); +Expr Infer(const Environment & env, const Expr & e); + +/*! \brief Ensures that an operator is well-formed with respect + * to Relay's type system. + */ +Operator CheckOperator(const Environment & env, const Operator & op); } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index dea3a99f5f09..c17a69dd0dc9 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -6,10 +6,12 @@ from .base import Span, NodeBase, register_relay_node from .type import Type, TypeParam from tvm import expr +from ._type_infer import _get_checked_type class Expr(NodeBase): """The base type for all Relay exprressions.""" - pass + def checked_type(self): + return _get_checked_type(self) @register_relay_node class Constant(Expr): diff --git a/python/tvm/relay/make.py b/python/tvm/relay/make.py index 236e2f6af596..bf9ec0e48f64 100644 --- a/python/tvm/relay/make.py +++ b/python/tvm/relay/make.py @@ -4,6 +4,9 @@ # Base Constructors Span = _make.Span +# Environment +Environment = _make.Environment + # Type Constructors TensorType = _make.TensorType TypeParam = _make.TypeParam diff --git a/src/relay/compiler/environment.cc b/src/relay/compiler/environment.cc index 125ceae834b3..af8f5eeefab7 100644 --- a/src/relay/compiler/environment.cc +++ b/src/relay/compiler/environment.cc @@ -187,10 +187,10 @@ Environment EnvironmentNode::make( // this->shape_exts_.Insert(ext->name, ext); // } -// TVM_REGISTER_API("relay._make.Environment") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// *ret = EnvironmentNode::make({}); -// }); +TVM_REGISTER_API("relay._make.Environment") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = EnvironmentNode::make({}); + }); // TVM_REGISTER_API("relay._env.Environment_add") // .set_body([](TVMArgs args, TVMRetValue *ret) { @@ -282,11 +282,11 @@ Environment EnvironmentNode::make( // *ret = env->get_defns(); // }); -// TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -// .set_dispatch([](const EnvironmentNode *node, -// tvm::IRPrinter *p) { -// p->stream << "EnvironmentNode(todo)"; // << node->items << ")"; -// }); +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const EnvironmentNode *node, + tvm::IRPrinter *p) { + p->stream << "EnvironmentNode(todo)"; // << node->items << ")"; + }); } // namespace relay } // namespace tvm diff --git a/src/relay/compiler/type_infer.cc b/src/relay/compiler/type_infer.cc index 0b7435598d6d..96d9dc92d97e 100644 --- a/src/relay/compiler/type_infer.cc +++ b/src/relay/compiler/type_infer.cc @@ -625,12 +625,13 @@ namespace relay { // } // } -// Type check(const Environment &env, const Expr &e) { -// Typechecker tc(env); -// return tc.Check(e); -// } +Expr Infer(const Environment &env, const Expr &e) { + //Typechecker tc(env); + // return tc.Check(e); + return e; +} -// Item check(const Environment &env, const Item &i) { +// Item Check(const Environment &env, const Item &i) { // Typechecker tc(env); // try { @@ -728,7 +729,7 @@ TVM_REGISTER_API("relay._type_infer.check_expr") .set_body([](TVMArgs args, TVMRetValue *ret) { Environment env = args[0]; Expr e = args[1]; - *ret = check(env, e); + *ret = Infer(env, e); }); // TVM_REGISTER_API("relay._tyck.check_item") diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py index e5466d4439b9..5626fd8ce0bc 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_typechecker.py @@ -2,8 +2,14 @@ for expressions. """ import tvm.relay.make as mk +from tvm.relay.type_infer import check_expr from tvm.relay.ir_builder import IRBuilder, float_type +def has_type(expr, typ): + env = mk.Environment({}) + checked_expr = check_expr(env, expr) + return checked_expr.checked_type() == typ + def test_monomorphic_let(): b = IRBuilder() # Program: let x = 1; x @@ -11,7 +17,6 @@ def test_monomorphic_let(): b.ret(x) prog = b.get() - e = check_expr(prog) - e.get_type() + assert has_type(prog, float_type()) From 66a80bdbecbe3c7a83cb36e9a38051d8317c3a95 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 02:24:00 -0700 Subject: [PATCH 26/88] Iterate on first test case --- include/tvm/relay/error.h | 10 + include/tvm/relay/expr.h | 8 +- include/tvm/relay/expr_visitor.h | 18 +- src/relay/compiler/resolve.cc | 99 ++ src/relay/compiler/resolve.h | 23 + src/relay/compiler/type_infer.cc | 1461 +++++++++++++++--------------- src/relay/compiler/unifier.h | 5 +- 7 files changed, 870 insertions(+), 754 deletions(-) create mode 100644 src/relay/compiler/resolve.cc create mode 100644 src/relay/compiler/resolve.h diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index d2698f8e380b..4f6a27d209c8 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -22,6 +22,16 @@ struct SpannedError { SpannedError(std::string msg, Span sp) : msg(msg), sp(sp) {} }; +// FIX, we should change spanned errors to have a method which allow them to report on the Environment, +// inverting control to error definition. +struct FatalTypeError : dmlc::Error { + explicit FatalTypeError(const std::string & s) : dmlc::Error(s) {} +}; + +struct TypecheckerError : public dmlc::Error { + explicit TypecheckerError(const std::string &msg) : Error(msg) {} +}; + } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index c1dd557717af..a29c8486ffb6 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -15,6 +15,10 @@ namespace tvm { namespace relay { + +// TOD0(@jroesch): best way to define? +class TypeInferencer; + /*! * \brief Relay expression. */ @@ -24,13 +28,14 @@ class Expr; */ class ExprNode : public RelayNode { public: + // private: /*! * \brief Stores the result of type inference(type checking). * * \note This can be undefined before type inference. * this value is discarded during serialization. */ - Type checked_type_ = Type(nullptr); + mutable Type checked_type_ = Type(nullptr); /*! * \return The checked_type */ @@ -43,6 +48,7 @@ class ExprNode : public RelayNode { static constexpr const char* _type_key = "relay.Expr"; TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode); + friend class TypeInferencer; }; RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef); diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 721fa531a7e3..2039414b4238 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -97,9 +97,16 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor ty_params; + tvm::Array ty_params; + for (auto ty : op->type_params) { - ty_params.push_back(this->VisitType(ty, args...)); + Type ty_param_type = VisitType(ty, args...); + if (auto ty_param = ty_param_type.as()) { + auto ty_param_ref = GetRef(ty_param); + ty_params.push_back(ty_param_ref); + } else { + throw dmlc::Error("the default func visitor has bug"); + } } tvm::Array params; @@ -115,7 +122,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctorVisitType(op->ret_type, args...); auto body = this->VisitExpr(op->body, args...); - return FunctionNode::make(ty_params, params, ret_type, body); + return FunctionNode::make(params, ret_type, body, ty_params); } Expr VisitExpr_(const CallNode* call_node, Args... args) override { @@ -132,8 +139,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctorVisitExpr(arg, args...)); } - auto call = CallNode::make(fn, call_args, call_node->attrs); - call->ty_args = ty_args; + auto call = CallNode::make(fn, call_args, call_node->attrs, ty_args); return call; } @@ -145,7 +151,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctorVisitType(op->value_type, args...); auto value = this->VisitExpr(op->value, args...); auto body = this->VisitExpr(op->body, args...); - return LetNode::make(var, type, value, body); + return LetNode::make(var, value, body, type); } else { throw dmlc::Error("the default let visitor has error"); } diff --git a/src/relay/compiler/resolve.cc b/src/relay/compiler/resolve.cc new file mode 100644 index 000000000000..2d3e84dc2160 --- /dev/null +++ b/src/relay/compiler/resolve.cc @@ -0,0 +1,99 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file unifier.cc + * \brief Data structures for type unification + */ + +#include "./resolve.h" +#include "./type_visitor.h" +#include "tvm/relay/expr_visitor.h" +#include "tvm/relay/ir.h" + +namespace tvm { +namespace relay { + +// We should probably generalize the subst code. +struct ResolveTypeType : TypeFVisitor { + const TypeUnifier &unifier; + + explicit ResolveTypeType(const TypeUnifier &unifier) : unifier(unifier) {} + + Type VisitType(const Type &t) override { + if (!t.defined()) { + auto inc_ty = IncompleteTypeNode::make(TypeParamNode::Kind::kType); + unifier->insert(inc_ty); + return inc_ty; + } else { + return TypeFVisitor::VisitType(t); + } + } + + Type VisitType_(const IncompleteTypeNode *op) override { + return unifier->subst(GetRef(op)); + } +}; + +struct ResolveTypeExpr : ExprFVisitor<> { + const TypeUnifier &unifier; + + explicit ResolveTypeExpr(const TypeUnifier &unifier) : unifier(unifier) {} + + Expr VisitExpr(const Expr &e) { + // NB: a bit tricky here. + // + // We want to store resolved type without having + // to re-typecheck the entire term. + // + // Since we know that e : T[...] under some holes + // then it is the case that if we resolve types + // present in e, then we can type it under T + // with the wholes filled in. + // + // We will visit e like normal building a new + // term, then resolve e's old type and write + // it back into the new node. + auto new_e = ExprFVisitor::VisitExpr(e); + auto resolved_cty = VisitType(e->checked_type_); + new_e->checked_type_ = resolved_cty; + return new_e; + } + + Type VisitType(const Type &t) { + return ResolveTypeType(unifier).VisitType(t); + } +}; + +Type resolve(const TypeUnifier &unifier, const Type &ty) { + return ResolveTypeType(unifier).VisitType(ty); +} + +Expr resolve(const TypeUnifier &unifier, const Expr &expr) { + return ResolveTypeExpr(unifier).VisitExpr(expr); +} + +struct FullyResolved : TypeVisitor<> { + bool incomplete; + + FullyResolved() : incomplete(true) {} + + void VisitType(const Type &t) override { + if (!t.defined()) { + incomplete = true; + } else { + return TypeVisitor<>::VisitType(t); + } + } + + void VisitType_(const IncompleteTypeNode *ty_var) override { + incomplete = false; + } +}; + +bool is_fully_resolved(const Type &t) { + auto fr = FullyResolved(); + fr.VisitType(t); + return fr.incomplete; +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/compiler/resolve.h b/src/relay/compiler/resolve.h new file mode 100644 index 000000000000..b4e164df6287 --- /dev/null +++ b/src/relay/compiler/resolve.h @@ -0,0 +1,23 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/options.h + * \brief Global options for the Relay IR. + */ +#ifndef TVM_RELAY_TYPECK_RESOLVE_H_ +#define TVM_RELAY_TYPECK_RESOLVE_H_ + +#include +#include "tvm/relay/ir.h" +#include "./unifier.h" + +namespace tvm { +namespace relay { + +Type resolve(const TypeUnifier & unifier, const Type & ty); +Expr resolve(const TypeUnifier & unifier, const Expr & expr); +bool is_fully_resolved(const Type & t); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TYPECK_RESOLVE_H_ diff --git a/src/relay/compiler/type_infer.cc b/src/relay/compiler/type_infer.cc index 96d9dc92d97e..49c8bbf9627f 100644 --- a/src/relay/compiler/type_infer.cc +++ b/src/relay/compiler/type_infer.cc @@ -2,771 +2,740 @@ * Copyright (c) 2018 by Contributors * \file type_infer.cc * \brief Relay type inference and checking. + * + * This file implements one of the most important passes to the + * Relay IR. In order to do many transformations and generate the + * most efficient code we need to obtain type information for the + * IR. + * + * Like computation graphs the IR leaves most type information + * implicit and relies performing analysis of the program to + * generate this information. + * + * This pass given an expression `e` will infer a type `t` for + * the expression simultaneous checking the property `e : t` + * (i.e we can show e has type t). + * + * If we can not infer a type or there are conflicting typing + * constraints we will trigger an error. */ +#include "tvm/relay/logging.h" #include "tvm/relay/compiler/type_infer.h" +#include "tvm/relay/error.h" +#include "tvm/relay/expr_functor.h" #include "./incomplete_type.h" +#include "./unifier.h" +#include "./resolve.h" // #include "tvm/relay/alpha_eq.h" // #include "tvm/relay/debug.h" // #include "tvm/relay/first_order_reverse_ad.h" // #include "tvm/relay/free_type_vars.h" // #include "tvm/relay/gen_fresh.h" // #include "tvm/relay/ir.h" -// #include "tvm/relay/logging.h" // #include "tvm/relay/pretty_printer.h" // #include "tvm/relay/reverse_ad.h" // #include "tvm/relay/type_visitor.h" // #include "tvm/relay/typeck/kindchecker.h" -// #include "tvm/relay/typeck/resolve.h" // #include "tvm/relay/typeck/shape_evaluator.h" namespace tvm { namespace relay { -// using namespace tvm::runtime; - -// struct FatalTypeError : dmlc::Error { -// explicit FatalTypeError(const std::string & s) : dmlc::Error(s) {} -// }; - -// struct TypeContext { -// std::vector> stack; -// TypeContext() { -// stack.push_back({}); -// } -// void insert(const LocalId &id, const Type &t) { stack.back()[id] = t; } -// Type lookup(const LocalId &id) { -// for (auto frame = stack.rbegin(); frame != stack.rend(); ++frame) { -// if (frame->find(id) != frame->end()) { -// return frame->at(id); -// } -// } -// throw FatalTypeError("Could not resolve local id"); -// } -// struct LocalFrame { -// TypeContext & tc; -// explicit LocalFrame(TypeContext & tc) : tc(tc) { -// tc.stack.push_back({}); -// } -// ~LocalFrame() { -// tc.stack.pop_back(); -// } -// }; -// }; - -// class Typechecker : private ExprFunctor { -// private: -// TypeContext local_stack; -// public: -// Environment env; -// TypeUnifier unifier; - -// template -// T with_frame(const std::function & f) { -// TypeContext::LocalFrame fr(local_stack); -// return f(); -// } - -// Typechecker(); -// Typechecker(Environment env, TypeUnifier unifier) : env(env), unifier(unifier) {} -// explicit Typechecker(Environment env); -// Type Check(const Expr & expr); -// Type instantiate(Type t, tvm::Array & ty_args); - -// void report_error(const std::string & msg, Span sp); -// [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); - -// Type unify(const Type &t1, const Type &t2, Span sp); -// Type resolve(const Type &t); -// Expr resolve(const Expr &e); -// Type VisitFunction(const Function & f, bool generalize); -// Operator CheckOp(Operator op); -// Defn CheckDefn(Defn def); -// private: -// Type VisitExpr_(const LocalIdNode* op) override; -// Type VisitExpr_(const GlobalIdNode* op) override; -// Type VisitExpr_(const OperatorIdNode* op) override; -// Type VisitExpr_(const FloatLitNode* op) override; -// Type VisitExpr_(const BoolLitNode* op) override; -// Type VisitExpr_(const IntLitNode* op) override; -// Type VisitExpr_(const TensorLitNode* op) override; -// Type VisitExpr_(const TupleNode* op) override; -// Type VisitExpr_(const CastNode* op) override; -// Type VisitExpr_(const ParamNode* op) override; -// Type VisitExpr_(const FunctionNode* op) override; -// Type VisitExpr_(const CallNode* op) override; -// Type VisitExpr_(const DebugNode* op) override; -// Type VisitExpr_(const LetNode* op) override; -// Type VisitExpr_(const ReverseNode* op) override; -// Type VisitExpr_(const GradientNode* op) override; -// Type VisitExpr_(const ProjectionNode* op) override; -// Type VisitExpr_(const IfNode* op) override; -// Type VisitExpr_(const RefNode* op) override; -// Type VisitExpr_(const ReadRefNode* op) override; -// Type VisitExpr_(const WriteRefNode* op) override; -// Type simple_eval_shape(const Type &shape); -// }; -// struct TypecheckerError : public dmlc::Error { -// explicit TypecheckerError(const std::string &msg) : Error(msg) {} -// }; - -// Typechecker::Typechecker() { -// this->env = EnvironmentNode::make({}); -// this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); -// } - -// Typechecker::Typechecker(Environment env) : env(env) { -// this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); -// } - -// Type Typechecker::Check(const Expr &expr) { -// RELAY_LOG(INFO) << "Typechecker::Check expr=" << expr << std::endl; -// Type ret = this->VisitExpr(expr); -// RELAY_LOG(INFO) << "Typechecker::Check type=" << expr << std::endl; -// ret = this->unifier->subst(ret); -// RELAY_LOG(INFO) << "Typechecker::Check type_after_subst=" << ret << std::endl; -// expr->checked_type_ = ret; -// return ret; -// } - -// Type Typechecker::VisitExpr_(const LocalIdNode *op) { -// LocalId id = GetRef(op); -// return this->local_stack.lookup(id); -// } - -// Type Typechecker::VisitExpr_(const GlobalIdNode *op) { -// GlobalId id = GetRef(op); -// Item item = this->env->lookup(id); - -// if (const OperatorNode *op = item.as()) { -// return op->type; -// } - -// if (const DefnNode *dn = item.as()) { -// Defn def = GetRef(dn); -// return def->type; -// } - -// this->fatal_error("Unhandled case in GlobalId", op->span); -// } - -// Type Typechecker::VisitExpr_(const OperatorIdNode *op) { -// OperatorId id = GetRef(op); -// Item item = this->env->lookup(id); - -// if (const OperatorNode *pn = item.as()) { -// Operator prim = GetRef(pn); -// return prim->type; -// } else { -// this->fatal_error("internal error in InstrinsicId case", op->span); -// } -// } - -// Type Typechecker::VisitExpr_(const FloatLitNode *op) { return FloatType(); } - -// Type Typechecker::VisitExpr_(const BoolLitNode *op) { return BoolType(); } - -// Type Typechecker::VisitExpr_(const IntLitNode *op) { return IntType(); } - -// Type Typechecker::VisitExpr_(const TensorLitNode *op) { -// TensorLit lit = GetRef(op); - -// if (lit->data.size() == 0) { -// this->fatal_error("Tensor literal must have at least one member", op->span); -// } - -// // unify types of all members to figure out shape, also ensure that -// // each member has compatible shape -// Type unified = this->Check(lit->data[0]); -// for (auto elt = lit->data.begin(); elt != lit->data.end(); elt++) { -// // evaluate all shape ASTs so they can be in standard form -// // TODO(sslyu): eventually we'd want this to be symbolic evaluation -// auto elt_el = *elt; -// Type elt_type = simple_eval_shape(this->Check(*elt)); -// if (!elt_type.as()) { -// this->fatal_error("All members in tensor literal must be tensors", -// elt_el->span); -// } -// unified = this->unify(unified, elt_type, lit->span); -// } - -// // types must unify into a tensor -// const TensorTypeNode *ttn = unified.as(); -// // shouldn't be possible due to check inside the loop -// if (!ttn) { -// this->fatal_error("Tensor literal contains non-tensor member", op->span); -// } - -// TensorType unified_tt = GetRef(ttn); - -// // new shape: add length of this tensor to front of existing shape -// // i.e., sequence and simplify -// // TODO(sslyu): should be symbolic evaluation eventually? -// Type new_shape = ShapeSeqNode::make( -// {ShapeSingletonNode::make(lit->data.size()), unified_tt->shape}); -// return TensorTypeNode::make(unified_tt->dtype, simple_eval_shape(new_shape)); -// } - -// Type Typechecker::VisitExpr_(const TupleNode *op) { -// Tuple pl = GetRef(op); - -// std::vector field_types; -// for (auto field = pl->fields.begin(); field != pl->fields.end(); field++) { -// field_types.push_back(this->Check(*field)); -// } - -// return TupleTypeNode::make(field_types); -// } - -// Type Typechecker::VisitExpr_(const CastNode *op) { -// // will take the cast at its word -// Cast cast = GetRef(op); -// return cast->target; -// } - -// Type Typechecker::VisitExpr_(const ParamNode *op) { -// Param param = GetRef(op); -// return resolve(param->type); -// } - -// // We should probably generalize the subst code. -// struct GeneralizeTypeType : TypeFVisitor { -// Map vars_to_id; -// const TypeUnifier &unifier; - -// GeneralizeTypeType(Map vars_to_id, -// const TypeUnifier &unifier) -// : vars_to_id(vars_to_id), unifier(unifier) {} - -// Type VisitType_(const TypeVarNode *op) override { -// auto repr = unifier->subst(GetRef(op)); -// if (auto tvn = repr.as()) { -// auto ty_var = GetRef(tvn); -// if (vars_to_id.find(ty_var) != vars_to_id.end()) { -// return vars_to_id[ty_var]; -// } else { -// return ty_var; -// } -// } else { -// return this->VisitType(repr); -// } -// } -// }; - -// struct GeneralizeTypeExpr : ExprFVisitor<> { -// Map vars_to_id; -// const TypeUnifier &unifier; - -// GeneralizeTypeExpr(const TypeUnifier &unifier, -// Map vars_to_id) -// : vars_to_id(vars_to_id), unifier(unifier) {} - -// Type VisitType(const Type &t) { -// return GeneralizeTypeType(vars_to_id, unifier).VisitType(t); -// } -// }; - -// Type Typechecker::VisitFunction(const Function &f, bool generalize) { -// // enter params into context -// auto fn_type = this->with_frame([&]() { -// std::vector arg_types; -// for (auto arg : f->params) { -// this->Check(arg); -// Type arg_type; -// // if arg type can be simply evaluated, try it -// // should be replaced with symbolic evaluation once it exists, -// // you will not have attr information at this point -// try { -// arg_type = simple_eval_shape(arg->type); -// } catch (const dmlc::Error &e) { -// this->report_error(e.what(), arg->span); -// arg_type = arg->type; -// } -// arg_types.push_back(arg_type); -// this->local_stack.insert(arg->id, arg_type); -// } - -// // typecheck body and ensure that it matches stated return type -// // TODO(sslyu): should the unified return type override the annotated one? -// Type checked_return = this->Check(f->body); -// Type ret_type = resolve(f->ret_type); -// Type unified = this->unify(simple_eval_shape(ret_type), -// simple_eval_shape(checked_return), f->span); -// return TypeArrowNode::make(arg_types, unified); -// }); -// if (generalize) { -// auto free_vars = free_type_vars(resolve(fn_type)); -// std::set dedup_free_vars; - -// for (auto free_var : free_vars) { -// auto repr = this->unifier->subst(free_var); -// if (auto new_free_var_node = repr.as()) { -// dedup_free_vars.insert(GetRef(new_free_var_node)); -// } else { -// // debug(repr); -// throw dmlc::Error( -// "internal error: this list should only contain type var nodes"); -// } -// } - -// Map vars_to_id; - -// GenFresh gf; -// for (auto free_var : dedup_free_vars) { -// vars_to_id.Set(free_var, gf.freshTV(free_var->kind)); -// } - -// fn_type = GeneralizeTypeType(vars_to_id, unifier).VisitType(fn_type); -// for (std::pair pair : vars_to_id) { -// // NB: In generalization we want to find type variables with -// // *no constraints* on them, and convert them to universally quantified -// // variables. -// // -// // i.e the program can be abstracted over the details of *that* type. - -// // For example a program that works irrespective of shape or datatype. - -// // In order to do this we find the set of free type variables in the -// // term, and then unify them with the fresh type ids we generate. -// // -// // Remember importantly these type variables still may appear in many -// // places in the program including both types and expressions. - -// // Our method for resolving these is to unify them with the variables -// // as we build the new quanitifer, changing from a program with "holes" -// // to one that is properly abstracted over. - -// // Finally later on we can iterate over the whole term and change from -// // type variables to these type ids. -// this->unify(pair.first, pair.second, pair.second->span); -// fn_type = TypeQuantifierNode::make(pair.second, fn_type); -// } -// } else { -// for (auto i = f->ty_params.size(); i > 0; i--) { -// auto ty_param = f->ty_params[i - 1]; -// auto ty_param_node = ty_param.as(); -// if (!ty_param_node) { -// throw dmlc::Error("internal error should be TypeParam"); -// } -// auto fresh_tid = -// TypeParamNode::make(ty_param_node->name, ty_param_node->kind); -// fn_type = -// type_subst(fn_type, GetRef(ty_param_node), fresh_tid); -// fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); -// } -// } - -// return fn_type; -// } - -// Type Typechecker::VisitExpr_(const FunctionNode *op) { -// return this->VisitFunction(GetRef(op), false); -// } - -// Type Typechecker::instantiate(Type t, tvm::Array &ty_args) { -// const TypeQuantifierNode *ty_quant; -// while ((ty_quant = t.as())) { -// TypeParam id = ty_quant->id; -// TypeVar fresh = TypeVarNode::make(id->kind); -// this->unifier->insert(fresh); -// ty_args.push_back(fresh); -// t = type_subst(ty_quant->boundType, id, fresh); -// } - -// if (!check_kind(t)) { -// this->fatal_error("Kind rules broken when instantiating type variables", -// t->span); -// } - -// return t; -// } - -// Type Typechecker::VisitExpr_(const CallNode *op) { -// Call c = GetRef(op); -// Type fn_ty = this->Check(c->fn); - -// RELAY_LOG(INFO) << "Typechecker::VisitExpr_ op=" << c << std::endl -// << "fn_ty=" << fn_ty << std::endl; - -// // for each type id, insert a type variable and unify with the argument types -// // in order -// // to obtain the concrete instantiation -// tvm::Array ty_args; -// if (const TypeQuantifierNode *ty_quant = fn_ty.as()) { -// fn_ty = instantiate(GetRef(ty_quant), ty_args); -// } - -// if (!fn_ty.as()) { -// this->fatal_error("only expressions with function types can be called", -// c->fn->span); -// } - -// // evaluate all shapes up front (require that types be fully concrete) -// Type evaluated = evaluate_concrete_shape(fn_ty, op->attrs); -// std::vector arg_types; - -// TypeArrow arrow = GetRef(evaluated.as()); - -// // TODO(sslyu): figure out how to handle type ids -// // fn_ty = instantiate(fn_ty, ty_args); -// for (auto arg : c->args) { -// auto ty = this->Check(arg); -// arg_types.push_back(ty); -// } - -// auto type_arity = arrow->arg_types.size(); -// auto number_of_args = arg_types.size(); -// if (type_arity != number_of_args) { -// if (type_arity < number_of_args) { -// this->fatal_error("the function is provided too many arguments", c->span); -// } else { -// this->fatal_error("the function is provided too few arguments", c->span); -// } -// } - -// for (size_t i = 0; i < arrow->arg_types.size(); i++) { -// this->unify(arrow->arg_types[i], arg_types[i], c->args[i]->span); -// } - -// // After we unify the arguments we should know more about the type -// // arguments, let's run a quick pass over them to find new representatives. -// for (size_t i = 0; i < ty_args.size(); i++) { -// ty_args.Set(i, this->unifier->subst(ty_args[i])); -// } - -// // Write the type arguments into the call node, recording what inference -// // solves. This solution might need some work. -// c->ty_args = ty_args; - -// return arrow->ret_type; -// } - -// Type Typechecker::VisitExpr_(const DebugNode *op) { -// return this->Check(op->node); -// } - -// Type Typechecker::VisitExpr_(const LetNode *op) { -// Let let = GetRef(op); - -// Type checked_ty; -// Type annotated_ty = resolve(let->type); - -// // if we are let-defining a function, treat it as a let-rec and insert -// // the id with the annotated type in case there is recursion; -// // no such recursion permitted with anything that's not a function! -// if (let->value.as()) { -// with_frame([&]() { -// local_stack.insert(let->id, annotated_ty); -// checked_ty = Check(let->value); -// }); -// } else { -// checked_ty = Check(let->value); -// } - -// // ensure annotated type and checked type are compatible -// // TODO(sslyu): should the annotated type override the unified one? -// Type unified_ty = -// this->unify(checked_ty, simple_eval_shape(annotated_ty), let->span); - -// return with_frame([&]() { -// local_stack.insert(let->id, unified_ty); -// return Check(let->body); -// }); -// } - -// Type Typechecker::VisitExpr_(const ReverseNode *op) { -// // apply reverse mode to node and typecheck that instead -// std::shared_ptr gf = std::make_shared(); -// return this->Check(ReverseExpr(env, op->node, gf)); -// } - -// Type Typechecker::VisitExpr_(const GradientNode *op) { -// auto node = op->node; -// this->Check(node); -// auto gf = std::make_shared(); -// return FOWithGradientType(node->checked_type()); -// } - -// Type Typechecker::VisitExpr_(const ProjectionNode *op) { -// Projection proj = GetRef(op); - -// Type tup_type = this->Check(proj->tuple); - -// const TupleTypeNode *ptn = tup_type.as(); -// if (!ptn) { -// this->fatal_error("Cannot project into non-product type", op->span); -// } - -// TupleType pt = GetRef(ptn); -// size_t field = (size_t)proj->field; -// if (field >= pt->fields.size()) { -// this->fatal_error("Projecting past bounds of product", op->span); -// } - -// return pt->fields[field]; -// } - -// Type Typechecker::VisitExpr_(const IfNode *op) { -// If ifn = GetRef(op); - -// // Ensure the type of the guard is of Tensor[Bool, ()], -// // that is a rank-0 boolean tensor. -// Type guardType = this->Check(ifn->guard); -// bool is_bool = false; -// bool zero_rank = false; -// if (const TensorTypeNode *ttn = guardType.as()) { -// TensorType tt = GetRef(ttn); - -// if (const BaseTypeNode *btn = tt->dtype.as()) { -// is_bool = btn->type.is_bool(); -// } - -// Type shape = simple_eval_shape(tt->shape); - -// if (const ShapeSeqNode *sn = shape.as()) { -// zero_rank = (sn->shapes.size() == 0); -// } -// } - -// if (!(is_bool && zero_rank)) { -// this->fatal_error("IfNode guard must be a rank 0 bool tensor", -// ifn->guard->span); -// } - -// // unify types of different branches -// Type left = this->Check(ifn->true_b); -// Type right = this->Check(ifn->false_b); -// return this->unify(left, right, ifn->span); -// } - -// Type Typechecker::VisitExpr_(const RefNode *op) { -// Ref r = GetRef(op); -// Type inner = this->Check(r->expr); -// return RefTypeNode::make(inner); -// } - -// Type Typechecker::VisitExpr_(const ReadRefNode *op) { -// ReadRef vr = GetRef(op); -// Type ref_type = this->Check(vr->ref); - -// // reject if not a ref type -// const RefTypeNode *rtn = ref_type.as(); -// if (!rtn) { -// this->fatal_error( -// "the de-reference operation can only be used with references", -// op->span); -// } - -// RefType rt = GetRef(rtn); -// return rt->data_type; -// } - -// Type Typechecker::VisitExpr_(const WriteRefNode *op) { -// WriteRef sr = GetRef(op); -// Type ref_type = this->Check(sr->ref); - -// const RefTypeNode *rtn = ref_type.as(); -// if (!rtn) { -// this->fatal_error("Cannot mutate non-ref", op->span); -// } -// RefType rt = GetRef(rtn); - -// // ensure ref type's inner type and expr's type are compatible; return unit -// Type expr_type = this->Check(sr->val); -// this->unify(rt->data_type, expr_type, sr->span); -// return UnitType(); -// } - -// Type Typechecker::resolve(const Type &t) { -// return ::tvm::relay::resolve(this->unifier, t); -// } - -// Expr Typechecker::resolve(const Expr &e) { -// return ::tvm::relay::resolve(this->unifier, e); -// } - -// Type Typechecker::simple_eval_shape(const Type &shape) { -// // TODO(sslyu): Do we want to propagate attributes? -// Attributes empty = AttributesNode::make({}); -// return evaluate_concrete_shape(shape, empty); -// } - -// Operator Typechecker::CheckOp(Operator op) { -// if (!check_kind(op->type)) { -// report_error("the type of the operator is ill formed", op->type->span); -// } - -// // Fix me -// return op; -// } - -// Defn Typechecker::CheckDefn(Defn defn) { -// // This is to handle recursion, but we need to speculatively -// // put it in env, then remove it. -// env->items.insert({defn->id, defn}); - -// Type expected_ty = this->resolve(defn->type); - -// Expr body = defn->body; - -// auto checked_ty = Check(body); - -// try { -// Type uret_type = unify(expected_ty, checked_ty, defn->body->span); -// CHECK(is_fully_resolved(uret_type)); -// // Now let's clean up our work from earlier. -// env->items.erase(defn->id); -// return DefnNode::make(defn->id, uret_type, this->resolve(defn->body)); -// } catch (const UnificationError& err) { -// std::string msg = std::string("mismatch between `") + -// PrintType(env, expected_ty, WrapWidth(40)) + "` and `" + -// PrintType(env, checked_ty, WrapWidth(40)) + "`"; -// fatal_error(msg, defn->span); -// } -// } - -Expr Infer(const Environment &env, const Expr &e) { - //Typechecker tc(env); - // return tc.Check(e); - return e; -} - -// Item Check(const Environment &env, const Item &i) { -// Typechecker tc(env); - -// try { -// if (const DefnNode *defn = i.as()) { -// return tc.CheckDefn(GetRef(defn)); -// } else if (const OperatorNode *op_node = i.as()) { -// return tc.CheckOp(GetRef(op_node)); -// } else { -// throw dmlc::Error("internal error: unknown Item type"); -// } -// } catch (const FatalTypeError &err) { -// env->display_errors(); -// throw dmlc::Error( -// "We encountered a fatal error while type checking your program, please " -// "read above for more details."); -// } -// } - -// inline void Typechecker::report_error(const std::string &msg, Span sp) { -// this->env->report_error(msg, sp); -// } - -// void Typechecker::fatal_error(const std::string &msg, Span sp) { -// this->env->report_error(msg, sp); -// throw FatalTypeError( -// "internal error: this exception should" -// "be handled and errors reported with Environment::display_errors\n" + -// msg); -// } - -// Type Typechecker::unify(const Type &t1, const Type &t2, Span sp) { -// try { -// return this->unifier->unify(t1, t2); -// } catch (const dmlc::Error &e) { -// std::stringstream ss; -// ss << "Error unifying `"; -// ss << PrintType(env, t1, WrapWidth(40)); -// ss << "` and `"; -// ss << PrintType(env, t2, WrapWidth(40)); -// ss << "`: " << e.what(); -// this->fatal_error(ss.str(), sp); -// } -// } - -// // template - -// // Add safe dynamic Array downcast. -// // Add static upcast? - -// // Add to type utils. -// Array type_parameters(const Type &t) { -// Array params; -// auto type = t; -// const TypeQuantifierNode *ty_quant; -// while ((ty_quant = type.as())) { -// params.push_back(ty_quant->id); -// type = ty_quant->boundType; -// } - -// return params; -// } - -// template -// Array ArrayMap(const Array &data, F f) { -// // probably a way to use std::transform. -// Array output; -// for (const I &el : data) { -// output.push_back(f(el)); -// } -// return output; -// } - -// // There are some important questions around generalization -// // that we need to answer. -// Expr generalize(const Environment &env, const Expr &e) { -// if (auto fn_node = e.as()) { -// Typechecker tc(env); -// auto ty = tc.VisitFunction(GetRef(fn_node), true); -// auto ty_params = type_parameters(ty); -// auto params = ArrayMap(fn_node->params, [&](const Param &p) { -// return ParamNode::make(p->id, tc.resolve(p->type)); -// }); -// auto body = tc.resolve(fn_node->body); -// auto ret_type = tc.resolve(fn_node->ret_type); -// auto fn = FunctionNode::make(ty_params, params, ret_type, body); -// // we should check in empty context to ensure typing is preserved. -// // check(env, fn); -// return fn; -// } else { -// throw dmlc::Error("can only apply generalize to a function."); -// } -// } - -TVM_REGISTER_API("relay._type_infer.check_expr") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - Expr e = args[1]; - *ret = Infer(env, e); - }); - -// TVM_REGISTER_API("relay._tyck.check_item") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// Item i = args[1]; -// *ret = check(env, i); -// }); - -TVM_REGISTER_API("relay._type_infer._get_checked_type") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Expr e = args[0]; - *ret = e->checked_type(); - }); - -// TVM_REGISTER_API("relay._tyck.generalize") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// *ret = generalize(args[0], args[1]); -// }); - -IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { - std::shared_ptr n = std::make_shared(); - n->kind = std::move(kind); - return IncompleteType(n); -} - -TVM_REGISTER_API("relay._make.IncompleteType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - int kind = args[0]; - *ret = IncompleteTypeNode::make(static_cast(kind)); - }); - -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const IncompleteTypeNode *node, - tvm::IRPrinter *p) { - p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; +using namespace tvm::runtime; + +struct TypeContext { + std::vector> stack; + + TypeContext() { stack.push_back({}); } + + void insert(const LocalVar &id, const Type &t) { stack.back()[id] = t; } + + Type lookup(const LocalVar &id) { + for (auto frame = stack.rbegin(); frame != stack.rend(); ++frame) { + if (frame->find(id) != frame->end()) { + return frame->at(id); + } + } + throw FatalTypeError("Could not resolve local id"); + } + + struct LocalFrame { + TypeContext &tc; + explicit LocalFrame(TypeContext &tc) : tc(tc) { tc.stack.push_back({}); } + ~LocalFrame() { tc.stack.pop_back(); } + }; +}; + +struct CheckedExpr { + Expr expr; + Type type; + CheckedExpr(Expr e, Type t) : expr(e), type(t) {} +}; + +class TypeInferencer : private ExprFunctor { + private: + TypeContext local_stack; + + public: + Environment env; + TypeUnifier unifier; + + // Should be in header? + template + T with_frame(const std::function & f) { + TypeContext::LocalFrame fr(local_stack); + return f(); + } + + TypeInferencer(); + TypeInferencer(Environment env, TypeUnifier unifier) : env(env), + unifier(unifier) {} explicit TypeInferencer(Environment env); + + CheckedExpr Infer(const Expr & expr); + + Type instantiate(Type t, tvm::Array &ty_args); + + void report_error(const std::string & msg, Span sp); + [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); + + Type unify(const Type &t1, const Type &t2, Span sp); + Type resolve(const Type &t); + Expr resolve(const Expr &e); + CheckedExpr VisitFunction(const Function & f, bool generalize); + // Operator CheckOp(Operator op); + // Defn CheckDefn(Defn def); + private: + CheckedExpr VisitExpr_(const LocalVarNode* op) override; + CheckedExpr VisitExpr_(const GlobalVarNode* op) override; + CheckedExpr VisitExpr_(const TupleNode* op) override; + CheckedExpr VisitExpr_(const ParamNode* op) override; + CheckedExpr VisitExpr_(const FunctionNode* op) override; + CheckedExpr VisitExpr_(const CallNode* op) override; + CheckedExpr VisitExpr_(const LetNode* op) override; + CheckedExpr VisitExpr_(const IfNode* op) override; +}; + + TypeInferencer::TypeInferencer() { + this->env = EnvironmentNode::make({}); + this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); + } + + TypeInferencer::TypeInferencer(Environment env) : env(env) { + this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); + } + + CheckedExpr TypeInferencer::Infer(const Expr &expr) { + RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl; + CheckedExpr checked_expr = this->VisitExpr(expr); + RELAY_LOG(INFO) << "TypeInferencer::Check type=" << checked_expr.type << std::endl; + Type final_type = this->unifier->subst(checked_expr.type); + RELAY_LOG(INFO) << "TypeInferencer::Check type_after_subst=" << final_type << std::endl; + checked_expr.expr->checked_type_ = final_type; + return checked_expr; + } + + CheckedExpr TypeInferencer::VisitExpr_(const LocalVarNode *op) { + auto var = GetRef(op); + return { var, this->local_stack.lookup(var) }; + } + + CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode *op) { + // GlobalVar id = GetRef(op); + // Item item = this->env->lookup(id); + + // if (const OperatorNode *op = item.as()) { + // return op->type; + // } + + // if (const DefnNode *dn = item.as()) { + // Defn def = GetRef(dn); + // return def->type; + // } + + // this->fatal_error("Unhandled case in GlobalId", op->span); + throw Error("hereeee"); + } + + // Type TypeInferencer::VisitExpr_(const OperatorIdNode *op) { + // OperatorId id = GetRef(op); + // Item item = this->env->lookup(id); + + // if (const OperatorNode *pn = item.as()) { + // Operator prim = GetRef(pn); + // return prim->type; + // } else { + // this->fatal_error("internal error in InstrinsicId case", op->span); + // } + // } + + CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { + // Tuple pl = GetRef(op); + + // std::vector field_types; + // for (auto field = pl->fields.begin(); field != pl->fields.end(); field++) + // { + // field_types.push_back(this->Check(*field)); + // } + + // return TupleTypeNode::make(field_types); + throw Error("TupleNode NYI"); + } + + CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *op) { + // Param param = GetRef(op); + // return { resolve(param->type); + throw Error("ParamNode NYI"); + } + + // // We should probably generalize the subst code. + // struct GeneralizeTypeType : TypeFVisitor { + // Map vars_to_id; + // const TypeUnifier &unifier; + + // GeneralizeTypeType(Map vars_to_id, + // const TypeUnifier &unifier) + // : vars_to_id(vars_to_id), unifier(unifier) {} + + // Type VisitType_(const TypeVarNode *op) override { + // auto repr = unifier->subst(GetRef(op)); + // if (auto tvn = repr.as()) { + // auto ty_var = GetRef(tvn); + // if (vars_to_id.find(ty_var) != vars_to_id.end()) { + // return vars_to_id[ty_var]; + // } else { + // return ty_var; + // } + // } else { + // return this->VisitType(repr); + // } + // } + // }; + + // struct GeneralizeTypeExpr : ExprFVisitor<> { + // Map vars_to_id; + // const TypeUnifier &unifier; + + // GeneralizeTypeExpr(const TypeUnifier &unifier, + // Map vars_to_id) + // : vars_to_id(vars_to_id), unifier(unifier) {} + + // Type VisitType(const Type &t) { + // return GeneralizeTypeType(vars_to_id, unifier).VisitType(t); + // } + // }; + + CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { + throw Error("FunctionNode NYI"); + // // enter params into context + // auto fn_type = this->with_frame([&]() { + // std::vector arg_types; + // for (auto arg : f->params) { + // this->Check(arg); + // Type arg_type; + // // if arg type can be simply evaluated, try it + // // should be replaced with symbolic evaluation once it exists, + // // you will not have attr information at this point + // try { + // arg_type = simple_eval_shape(arg->type); + // } catch (const dmlc::Error &e) { + // this->report_error(e.what(), arg->span); + // arg_type = arg->type; + // } + // arg_types.push_back(arg_type); + // this->local_stack.insert(arg->id, arg_type); + // } + + // // typecheck body and ensure that it matches stated return type + // // TODO(sslyu): should the unified return type override the annotated + // one? Type checked_return = this->Check(f->body); Type ret_type = + // resolve(f->ret_type); Type unified = + // this->unify(simple_eval_shape(ret_type), + // simple_eval_shape(checked_return), f->span); + // return TypeArrowNode::make(arg_types, unified); + // }); + // if (generalize) { + // auto free_vars = free_type_vars(resolve(fn_type)); + // std::set dedup_free_vars; + + // for (auto free_var : free_vars) { + // auto repr = this->unifier->subst(free_var); + // if (auto new_free_var_node = repr.as()) { + // dedup_free_vars.insert(GetRef(new_free_var_node)); + // } else { + // // debug(repr); + // throw dmlc::Error( + // "internal error: this list should only contain type var + // nodes"); + // } + // } + + // Map vars_to_id; + + // GenFresh gf; + // for (auto free_var : dedup_free_vars) { + // vars_to_id.Set(free_var, gf.freshTV(free_var->kind)); + // } + + // fn_type = GeneralizeTypeType(vars_to_id, unifier).VisitType(fn_type); + // for (std::pair pair : vars_to_id) { + // // NB: In generalization we want to find type variables with + // // *no constraints* on them, and convert them to universally + // quantified + // // variables. + // // + // // i.e the program can be abstracted over the details of *that* type. + + // // For example a program that works irrespective of shape or + // datatype. + + // // In order to do this we find the set of free type variables in the + // // term, and then unify them with the fresh type ids we generate. + // // + // // Remember importantly these type variables still may appear in many + // // places in the program including both types and expressions. + + // // Our method for resolving these is to unify them with the variables + // // as we build the new quanitifer, changing from a program with + // "holes" + // // to one that is properly abstracted over. + + // // Finally later on we can iterate over the whole term and change + // from + // // type variables to these type ids. + // this->unify(pair.first, pair.second, pair.second->span); + // fn_type = TypeQuantifierNode::make(pair.second, fn_type); + // } + // } else { + // for (auto i = f->ty_params.size(); i > 0; i--) { + // auto ty_param = f->ty_params[i - 1]; + // auto ty_param_node = ty_param.as(); + // if (!ty_param_node) { + // throw dmlc::Error("internal error should be TypeParam"); + // } + // auto fresh_tid = + // TypeParamNode::make(ty_param_node->name, ty_param_node->kind); + // fn_type = + // type_subst(fn_type, GetRef(ty_param_node), fresh_tid); + // fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); + // } + // } + + // return fn_type; + + } + + CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode *op) { + return this->VisitFunction(GetRef(op), false); + } + + // Type TypeInferencer::instantiate(Type t, tvm::Array &ty_args) { + // const TypeQuantifierNode *ty_quant; + // while ((ty_quant = t.as())) { + // TypeParam id = ty_quant->id; + // TypeVar fresh = TypeVarNode::make(id->kind); + // this->unifier->insert(fresh); + // ty_args.push_back(fresh); + // t = type_subst(ty_quant->boundType, id, fresh); + // } + + // if (!check_kind(t)) { + // this->fatal_error("Kind rules broken when instantiating type + // variables", + // t->span); + // } + + // return t; + // } + + CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { + throw Error("CallNode"); + // Call c = GetRef(op); + // Type fn_ty = this->Check(c->fn); + + // RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl + // << "fn_ty=" << fn_ty << std::endl; + + // // for each type id, insert a type variable and unify with the argument + // types + // // in order + // // to obtain the concrete instantiation + // tvm::Array ty_args; + // if (const TypeQuantifierNode *ty_quant = fn_ty.as()) + // { + // fn_ty = instantiate(GetRef(ty_quant), ty_args); + // } + + // if (!fn_ty.as()) { + // this->fatal_error("only expressions with function types can be called", + // c->fn->span); + // } + + // // evaluate all shapes up front (require that types be fully concrete) + // Type evaluated = evaluate_concrete_shape(fn_ty, op->attrs); + // std::vector arg_types; + + // TypeArrow arrow = GetRef(evaluated.as()); + + // // TODO(sslyu): figure out how to handle type ids + // // fn_ty = instantiate(fn_ty, ty_args); + // for (auto arg : c->args) { + // auto ty = this->Check(arg); + // arg_types.push_back(ty); + // } + + // auto type_arity = arrow->arg_types.size(); + // auto number_of_args = arg_types.size(); + // if (type_arity != number_of_args) { + // if (type_arity < number_of_args) { + // this->fatal_error("the function is provided too many arguments", + // c->span); + // } else { + // this->fatal_error("the function is provided too few arguments", + // c->span); + // } + // } + + // for (size_t i = 0; i < arrow->arg_types.size(); i++) { + // this->unify(arrow->arg_types[i], arg_types[i], c->args[i]->span); + // } + + // // After we unify the arguments we should know more about the type + // // arguments, let's run a quick pass over them to find new + // representatives. for (size_t i = 0; i < ty_args.size(); i++) { + // ty_args.Set(i, this->unifier->subst(ty_args[i])); + // } + + // // Write the type arguments into the call node, recording what inference + // // solves. This solution might need some work. + // c->ty_args = ty_args; + + // return arrow->ret_type; + } + + // Type TypeInferencer::VisitExpr_(const DebugNode *op) { + // return this->Check(op->node); + // } + + CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) { + Let let = GetRef(op); + + Type checked_ty; + Type annotated_ty = resolve(let->value_type); + + // // if we are let-defining a function, treat it as a let-rec and insert + // // the id with the annotated type in case there is recursion; + // // no such recursion permitted with anything that's not a function! + // if (let->value.as()) { + // with_frame([&]() { + // local_stack.insert(let->id, annotated_ty); + // checked_ty = Check(let->value); + // }); + // } else { + // checked_ty = Check(let->value); + // } + + // ensure annotated type and checked type are compatible + // TODO(sslyu): should the annotated type override the unified one? + Type unified_ty = + this->unify(checked_ty, annotated_ty, let->span); + + return with_frame([&]() { + local_stack.insert(let->var, unified_ty); + return Infer(let->body); }); + } + + // Type TypeInferencer::VisitExpr_(const ReverseNode *op) { + // // apply reverse mode to node and typecheck that instead + // std::shared_ptr gf = std::make_shared(); + // return this->Check(ReverseExpr(env, op->node, gf)); + // } + + // Type TypeInferencer::VisitExpr_(const GradientNode *op) { + // auto node = op->node; + // this->Check(node); + // auto gf = std::make_shared(); + // return FOWithGradientType(node->checked_type()); + // } + + // Type TypeInferencer::VisitExpr_(const ProjectionNode *op) { + // Projection proj = GetRef(op); + + // Type tup_type = this->Check(proj->tuple); + + // const TupleTypeNode *ptn = tup_type.as(); + // if (!ptn) { + // this->fatal_error("Cannot project into non-product type", op->span); + // } + + // TupleType pt = GetRef(ptn); + // size_t field = (size_t)proj->field; + // if (field >= pt->fields.size()) { + // this->fatal_error("Projecting past bounds of product", op->span); + // } + + // return pt->fields[field]; + // } + + CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { + // If ifn = GetRef(op); + + // // Ensure the type of the guard is of Tensor[Bool, ()], + // // that is a rank-0 boolean tensor. + // Type guardType = this->Check(ifn->guard); + // bool is_bool = false; + // bool zero_rank = false; + // if (const TensorTypeNode *ttn = guardType.as()) { + // TensorType tt = GetRef(ttn); + + // if (const BaseTypeNode *btn = tt->dtype.as()) { + // is_bool = btn->type.is_bool(); + // } + + // Type shape = simple_eval_shape(tt->shape); + + // if (const ShapeSeqNode *sn = shape.as()) { + // zero_rank = (sn->shapes.size() == 0); + // } + // } + + // if (!(is_bool && zero_rank)) { + // this->fatal_error("IfNode guard must be a rank 0 bool tensor", + // ifn->guard->span); + // } + + // // unify types of different branches + // Type left = this->Check(ifn->true_b); + // Type right = this->Check(ifn->false_b); + // return this->unify(left, right, ifn->span); + } + + // Type TypeInferencer::VisitExpr_(const RefNode *op) { + // Ref r = GetRef(op); + // Type inner = this->Check(r->expr); + // return RefTypeNode::make(inner); + // } + + // Type TypeInferencer::VisitExpr_(const ReadRefNode *op) { + // ReadRef vr = GetRef(op); + // Type ref_type = this->Check(vr->ref); + + // // reject if not a ref type + // const RefTypeNode *rtn = ref_type.as(); + // if (!rtn) { + // this->fatal_error( + // "the de-reference operation can only be used with references", + // op->span); + // } + + // RefType rt = GetRef(rtn); + // return rt->data_type; + // } + + // Type TypeInferencer::VisitExpr_(const WriteRefNode *op) { + // WriteRef sr = GetRef(op); + // Type ref_type = this->Check(sr->ref); + + // const RefTypeNode *rtn = ref_type.as(); + // if (!rtn) { + // this->fatal_error("Cannot mutate non-ref", op->span); + // } + // RefType rt = GetRef(rtn); + + // // ensure ref type's inner type and expr's type are compatible; return + // unit Type expr_type = this->Check(sr->val); this->unify(rt->data_type, + // expr_type, sr->span); return UnitType(); + // } + + Type TypeInferencer::resolve(const Type &t) { + return ::tvm::relay::resolve(this->unifier, t); + } + + Expr TypeInferencer::resolve(const Expr &e) { + return ::tvm::relay::resolve(this->unifier, e); + } + + // Operator TypeInferencer::CheckOp(Operator op) { + // if (!check_kind(op->type)) { + // report_error("the type of the operator is ill formed", op->type->span); + // } + + // // Fix me + // return op; + // } + + // Defn TypeInferencer::CheckDefn(Defn defn) { + // // This is to handle recursion, but we need to speculatively + // // put it in env, then remove it. + // env->items.insert({defn->id, defn}); + + // Type expected_ty = this->resolve(defn->type); + + // Expr body = defn->body; + + // auto checked_ty = Check(body); + + // try { + // Type uret_type = unify(expected_ty, checked_ty, defn->body->span); + // CHECK(is_fully_resolved(uret_type)); + // // Now let's clean up our work from earlier. + // env->items.erase(defn->id); + // return DefnNode::make(defn->id, uret_type, this->resolve(defn->body)); + // } catch (const UnificationError& err) { + // std::string msg = std::string("mismatch between `") + + // PrintType(env, expected_ty, WrapWidth(40)) + "` and + // `" + PrintType(env, checked_ty, WrapWidth(40)) + + // "`"; + // fatal_error(msg, defn->span); + // } + // } + + Expr Infer(const Environment &env, const Expr &e) { + TypeInferencer ti(env); + auto checked_expr = ti.Infer(e); + return checked_expr.expr; + } + + // Item Check(const Environment &env, const Item &i) { + // TypeInferencer tc(env); + + // try { + // if (const DefnNode *defn = i.as()) { + // return tc.CheckDefn(GetRef(defn)); + // } else if (const OperatorNode *op_node = i.as()) { + // return tc.CheckOp(GetRef(op_node)); + // } else { + // throw dmlc::Error("internal error: unknown Item type"); + // } + // } catch (const FatalTypeError &err) { + // env->display_errors(); + // throw dmlc::Error( + // "We encountered a fatal error while type checking your program, + // please " "read above for more details."); + // } + // } + + inline void TypeInferencer::report_error(const std::string &msg, Span sp) { + // this->env->report_error(msg, sp); + } + + void TypeInferencer::fatal_error(const std::string &msg, Span sp) { + // this->env->report_error(msg, sp); + throw FatalTypeError( + "internal error: this exception should" + "be handled and errors reported with Environment::display_errors\n" + + msg); + } + + Type TypeInferencer::unify(const Type &t1, const Type &t2, Span sp) { + try { + return this->unifier->unify(t1, t2); + } catch (const dmlc::Error &e) { + std::stringstream ss; + ss << "Error unifying `"; + ss << t1; + // ss << PrintType(env, t1, WrapWidth(40)); + ss << "` and `"; + ss << t2; + // ss << PrintType(env, t2, WrapWidth(40)); + ss << "`: " << e.what(); + this->fatal_error(ss.str(), sp); + } + } + + // // template + + // // Add safe dynamic Array downcast. + // // Add static upcast? + + // // Add to type utils. + // Array type_parameters(const Type &t) { + // Array params; + // auto type = t; + // const TypeQuantifierNode *ty_quant; + // while ((ty_quant = type.as())) { + // params.push_back(ty_quant->id); + // type = ty_quant->boundType; + // } + + // return params; + // } + + // template + // Array ArrayMap(const Array &data, F f) { + // // probably a way to use std::transform. + // Array output; + // for (const I &el : data) { + // output.push_back(f(el)); + // } + // return output; + // } + + // // There are some important questions around generalization + // // that we need to answer. + // Expr generalize(const Environment &env, const Expr &e) { + // if (auto fn_node = e.as()) { + // TypeInferencer tc(env); + // auto ty = tc.VisitFunction(GetRef(fn_node), true); + // auto ty_params = type_parameters(ty); + // auto params = ArrayMap(fn_node->params, [&](const Param &p) { + // return ParamNode::make(p->id, tc.resolve(p->type)); + // }); + // auto body = tc.resolve(fn_node->body); + // auto ret_type = tc.resolve(fn_node->ret_type); + // auto fn = FunctionNode::make(ty_params, params, ret_type, body); + // // we should check in empty context to ensure typing is preserved. + // // check(env, fn); + // return fn; + // } else { + // throw dmlc::Error("can only apply generalize to a function."); + // } + // } + + TVM_REGISTER_API("relay._type_infer.check_expr") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + Expr e = args[1]; + *ret = Infer(env, e); + }); + + // TVM_REGISTER_API("relay._tyck.check_item") + // .set_body([](TVMArgs args, TVMRetValue *ret) { + // Environment env = args[0]; + // Item i = args[1]; + // *ret = check(env, i); + // }); + + TVM_REGISTER_API("relay._type_infer._get_checked_type") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Expr e = args[0]; + *ret = e->checked_type(); + }); + + // TVM_REGISTER_API("relay._tyck.generalize") + // .set_body([](TVMArgs args, TVMRetValue *ret) { + // *ret = generalize(args[0], args[1]); + // }); + + IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { + std::shared_ptr n = + std::make_shared(); + n->kind = std::move(kind); + return IncompleteType(n); + } + + TVM_REGISTER_API("relay._make.IncompleteType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + int kind = args[0]; + *ret = IncompleteTypeNode::make(static_cast(kind)); + }); + + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const IncompleteTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; + }); } // namespace relay -} // namespace tvm +} // namespace relay diff --git a/src/relay/compiler/unifier.h b/src/relay/compiler/unifier.h index cba96ff02451..86ffd664a161 100644 --- a/src/relay/compiler/unifier.h +++ b/src/relay/compiler/unifier.h @@ -99,9 +99,12 @@ class TypeUnifierNode : public Node, TVM_DECLARE_NODE_TYPE_INFO(TypeUnifierNode, Node); private: - // unify non-typevar with typevar + /*! \brief Unify incomplete type with another type. */ Type unifyWithIncompleteType(const Type& t1, const IncompleteType tvn2); + /*! \brief Implements unification between two types with incomplete portions. */ Type VisitType(const Type & t1, const Type t2) override; + + // Visitor Cases Type VisitType_(const IncompleteTypeNode* t1, const Type t2) override; Type VisitType_(const TensorTypeNode* t1, const Type t2) override; Type VisitType_(const TypeParamNode* t1, const Type t2) override; From b5e655af2bcb7fabce4271e541f40fb5579d1b95 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 02:30:53 -0700 Subject: [PATCH 27/88] First simple test passes --- src/relay/compiler/type_infer.cc | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/relay/compiler/type_infer.cc b/src/relay/compiler/type_infer.cc index 49c8bbf9627f..7304bdabe486 100644 --- a/src/relay/compiler/type_infer.cc +++ b/src/relay/compiler/type_infer.cc @@ -108,6 +108,7 @@ class TypeInferencer : private ExprFunctor { private: CheckedExpr VisitExpr_(const LocalVarNode* op) override; CheckedExpr VisitExpr_(const GlobalVarNode* op) override; + CheckedExpr VisitExpr_(const ConstantNode* op) override; CheckedExpr VisitExpr_(const TupleNode* op) override; CheckedExpr VisitExpr_(const ParamNode* op) override; CheckedExpr VisitExpr_(const FunctionNode* op) override; @@ -157,6 +158,15 @@ class TypeInferencer : private ExprFunctor { throw Error("hereeee"); } + CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode *const_node) { + auto array = const_node->data; + // array->t + // first pass + return { + GetRef(const_node), + TensorTypeNode::make({}, HalideIR::Float(32, 1)) }; + } + // Type TypeInferencer::VisitExpr_(const OperatorIdNode *op) { // OperatorId id = GetRef(op); // Item item = this->env->lookup(id); @@ -423,16 +433,16 @@ class TypeInferencer : private ExprFunctor { Type checked_ty; Type annotated_ty = resolve(let->value_type); - // // if we are let-defining a function, treat it as a let-rec and insert - // // the id with the annotated type in case there is recursion; - // // no such recursion permitted with anything that's not a function! + // if we are let-defining a function, treat it as a let-rec and insert + // the id with the annotated type in case there is recursion; + // no such recursion permitted with anything that's not a function! // if (let->value.as()) { - // with_frame([&]() { - // local_stack.insert(let->id, annotated_ty); - // checked_ty = Check(let->value); - // }); + // with_frame([&]() { + // local_stack.insert(let->id, annotated_ty); + // checked_ty = Check(let->value); + // }); // } else { - // checked_ty = Check(let->value); + checked_ty = Infer(let->value).type; // } // ensure annotated type and checked type are compatible From 9cdfb9a0c59e5c2ddcebe8293d810df43de9a573 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 10:37:34 -0700 Subject: [PATCH 28/88] Iterate towards second test --- include/tvm/relay/compiler/environment.h | 13 +++--- include/tvm/relay/op.h | 1 + python/tvm/relay/ir_builder.py | 54 +++++++++++++++++++++--- src/relay/compiler/environment.cc | 20 ++++----- tests/python/relay/test_typechecker.py | 11 ++++- 5 files changed, 72 insertions(+), 27 deletions(-) diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/compiler/environment.h index ddb7f0dca192..3e108cd8b390 100644 --- a/include/tvm/relay/compiler/environment.h +++ b/include/tvm/relay/compiler/environment.h @@ -42,9 +42,9 @@ class EnvironmentNode : public RelayNode { /*! A map from string names to GlobalIds, ensures global uniqueness. */ InternTable global_map_; /*! A map from string names to Operators, ensures global uniqueness. */ - InternTable operator_map_; + InternTable operators; // /*! \brief A map from file names to source fragments. */ - // SourceMap source_map_; + // SourceMap source_map_ // /*! \brief A list of the errors reported during the current run. */ // std::vector errors_; @@ -64,8 +64,8 @@ class EnvironmentNode : public RelayNode { TVM_DLL static Environment make( std::unordered_map global_funcs); - // Add an item to the Enviroment. - // void add(const Operator& op, bool update = false); + /*! Add an operator to the Enviroment. */ + void register_op(const Operator& op); // void add(const Operator& op, bool update = false); // void try_add(const Item& item, bool update=false); @@ -73,13 +73,10 @@ class EnvironmentNode : public RelayNode { // void remove(const GlobalId& id); // GlobalId global_id(const std::string& str); - // OperatorId operator_id(const std::string& str); + Operator op(const std::string& str); // We can lookup a GlobalId, OperatorId. // Defn lookup(const GlobalId& id); - // Operator lookup(const OperatorId& id); - // Defn lookup_global(const std::string& str); - // Item lookup_operator(const std::string& str); // FileId add_source(std::string file_name, std::string source); // tvm::Array get_operators(); diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index fa152945d38c..2e631df55fd0 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -25,6 +25,7 @@ class Operator; /*! \brief Container for Operator */ class OperatorNode : public ExprNode { public: + std::string name; /*! \brief A type which specifies the relationship between the inputs and outputs * of the operator. */ diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 2b2cdb432b43..35bc31265987 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -5,6 +5,12 @@ from . import expr from . import make as mk +class ExprBuilder(): + def __init__(self, expr): + self.expr = expr + + def __call__(self, *args): + return ExprBuilder(mk.Call(self.expr, list(args), None, None)) def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: """Convert Python values into the appropriate types @@ -29,7 +35,7 @@ def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> expr.Expr: raise Exception("..") else: value = convert(arg, ctxt) - return mk.Constant(value) + return ExprBuilder(mk.Constant(value)) class WithScope(object): """Auxiliary scope with""" @@ -44,6 +50,18 @@ def __enter__(self): def __exit__(self, ptype, value, trace): self._exit_cb() + +class PartialFunc(): + def __init__(self, params, ret_type, body, type_params): + self.params = params + self.ret_type = ret_type + self.body = body + self.type_params = type_params + + def param_ids(self): + return [p.var for p in self.params] + + def _mk_let(bindings, ret_value): let_expr = ret_value for var, value in reversed(list(bindings.items())): @@ -51,12 +69,15 @@ def _mk_let(bindings, ret_value): return let_expr + class IRBuilder(): def __init__(self): self.bindings = [{}] self.scopes = [{}] + self.params = [] self.ret_value = None + def bind(self, name, type, value): lv = mk.LocalVar(name) self.scopes[-1][name] = lv @@ -65,18 +86,33 @@ def bind(self, name, type, value): def let(self, name, value, value_type=None): - if not isinstance(value, expr.Expr): + if not (isinstance(value, expr.Expr) or isinstance(value, ExprBuilder)): value = into_ast(value) + if isinstance(value, ExprBuilder): + value = value.expr + return self.bind(name, value_type, value) - def function(self, params): + def function(self, *params): + relay_params = [] + for name, ty in params: + lv = mk.LocalVar(name) + self.scopes[-1][name] = lv + relay_params.append(mk.Param(lv, ty)) + + # self.params.append(relay_params) + + pfunc = PartialFunc(relay_params, None, None, []) + def _on_exit(): bindings = self.bindings.pop() scope = self.scopes.pop() - import pdb - pdb.set_trace() - return WithScope(None, _on_exit) + # params = self.params.pop() + + + return WithScope(pfunc, _on_exit) + def ret(self, x): if not self.ret_value: @@ -85,6 +121,12 @@ def ret(self, x): raise Exception( "return value already set, a function can only have one return value") + def fn_params(self): + pass + + def op(self, name): + pass + def get(self): """Get the full program""" bindings = self.bindings.pop() diff --git a/src/relay/compiler/environment.cc b/src/relay/compiler/environment.cc index af8f5eeefab7..735ef79ceb3a 100644 --- a/src/relay/compiler/environment.cc +++ b/src/relay/compiler/environment.cc @@ -24,6 +24,10 @@ Environment EnvironmentNode::make( return Environment(n); } +void EnvironmentNode::register_op(const Operator& op) { + this->operators.Insert(op->name, op); +} + // tvm::PackedFunc EnvironmentNode::jit_for(OperatorId id) { // return this->lookup(id)->compiler; // } @@ -111,18 +115,10 @@ Environment EnvironmentNode::make( // } // } -// Operator EnvironmentNode::lookup(const OperatorId &id) { -// if (operators.find(id) != operators.end()) { -// return operators.at(id); -// } else { -// throw EnvError(std::string("there is no definition of ") + id->name); -// } -// } - -// Item EnvironmentNode::lookup_operator(const std::string &str) { -// OperatorId id = this->operator_id(str); -// return lookup(id); -// } +Operator EnvironmentNode::op(const std::string & op_name) { + // FIX ME + return operators.Lookup(op_name); +} // Defn EnvironmentNode::lookup_global(const std::string &str) { // GlobalId id = this->global_id(str); diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py index 5626fd8ce0bc..c6bc1a05ebe9 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_typechecker.py @@ -11,8 +11,8 @@ def has_type(expr, typ): return checked_expr.checked_type() == typ def test_monomorphic_let(): + "Program: let x = 1; x" b = IRBuilder() - # Program: let x = 1; x x = b.let('x', 1, value_type=float_type()) b.ret(x) @@ -20,3 +20,12 @@ def test_monomorphic_let(): assert has_type(prog, float_type()) +def test_single_op(): + "Program: fn (x : int32) { let t1 = f(x); t1 }" + b = IRBuilder() + f = b.op('f') + with b.function(('x', float_type())) as func: + x, = func.param_ids() + t1 = b.let('t1', f(x)) + b.ret(t1) + import pdb; pdb.set_trace() From 1cfc2715861cfa1e13cd748b43ef621469e93eab Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 16:55:31 -0700 Subject: [PATCH 29/88] Remove placeholder op.h --- include/tvm/relay/op.h | 48 ------------------------------------------ 1 file changed, 48 deletions(-) delete mode 100644 include/tvm/relay/op.h diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h deleted file mode 100644 index 2e631df55fd0..000000000000 --- a/include/tvm/relay/op.h +++ /dev/null @@ -1,48 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file tvm/relay/op.h - * \brief Relay's representation of operators. - */ -#ifndef TVM_RELAY_OP_H_ -#define TVM_RELAY_OP_H_ - -#include "./expr.h" - -namespace tvm { -namespace relay { - - -/*! - * \brief A primitive Relay operator defined externally to Relay. - * - * \note Currently these are expected to be backed by a TVM's operator, - * such as the ones defined in TOPI. - * - * For developers who are familar with the computational graph this - * directly maps to the concept of operators in NNVM. - */ -class Operator; -/*! \brief Container for Operator */ -class OperatorNode : public ExprNode { - public: - std::string name; - /*! \brief A type which specifies the relationship between the inputs and outputs - * of the operator. - */ - Type op_type; - - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("op_type", &op_type); - } - - TVM_DLL static Operator make(Type op_type); - - static constexpr const char* _type_key = "relay.Operator"; - TVM_DECLARE_NODE_TYPE_INFO(OperatorNode, OperatorNode); -}; - -RELAY_DEFINE_NODE_REF(Operator, OperatorNode, Expr); - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_EXPR_H_ From 491880e7224188381a70a0296ebd5c3db1c9d813 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 22 Aug 2018 09:44:43 -0700 Subject: [PATCH 30/88] [OP] Current op system --- include/tvm/base.h | 5 + include/tvm/relay/op.h | 402 ++++++++++++++++++++++++++++ python/tvm/relay/__init__.py | 3 +- python/tvm/relay/make.py | 6 + python/tvm/relay/op.py | 37 +++ src/relay/op.cc | 105 ++++++-- src/relay/op/tensor/elemwise.cc | 23 ++ tests/python/relay/test_relay_op.py | 8 + 8 files changed, 568 insertions(+), 21 deletions(-) create mode 100644 include/tvm/relay/op.h create mode 100644 python/tvm/relay/op.py create mode 100644 src/relay/op/tensor/elemwise.cc create mode 100644 tests/python/relay/test_relay_op.py diff --git a/include/tvm/base.h b/include/tvm/base.h index c2d796b6002c..be848b34cd43 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -134,6 +134,11 @@ struct NodeFactoryReg { */ #define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) +#define TVM_REGISTER_NODE_TYPE(TypeName) \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \ + ::dmlc::Registry<::tvm::NodeFactoryReg>::Get()->__REGISTER__(TypeName::_type_key) \ + .set_body([]() { return std::make_shared(); }) + } // namespace tvm #endif // TVM_BASE_H_ diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h new file mode 100644 index 000000000000..f7e1cfbbc8c2 --- /dev/null +++ b/include/tvm/relay/op.h @@ -0,0 +1,402 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/op.h + * \brief Primitive operator definition. + */ +#ifndef TVM_RELAY_OP_H_ +#define TVM_RELAY_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "./base.h" +#include "./expr.h" +#include "../attrs.h" + +namespace tvm { +namespace relay { + +// forward declare name. +template +class OpMap; +class GenericOpMap; +class OpRegistry; + +/*! + * \brief Node container of operator structure. + */ +class OpNode : public relay::ExprNode { + public: + /*! \brief name of the operator */ + std::string name; + /*! + * \brief detailed description of the operator + * This can be used to generate docstring automatically for the operator. + */ + std::string description; + /* \brief Information of input arguments to the operator */ + Array arguments; + /*! + * \brief The type key of the attribute field + * This can be empty, in which case it defaults to + */ + std::string attrs_type_key; + /*! + * \brief number of input arguments to the operator, + * -1 means it is variable length + */ + int32_t num_inputs = -1; + /*! + * \brief support level of the operator, + * The lower the more priority it contains. + * This is in analogies to BLAS levels. + */ + int32_t support_level = 10; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("description", &description); + v->Visit("arguments", &arguments); + v->Visit("attrs_type_key", &attrs_type_key); + v->Visit("num_inputs", &num_inputs); + v->Visit("support_level", &support_level); + } + + static constexpr const char* _type_key = "relay.Op"; + TVM_DECLARE_NODE_TYPE_INFO(OpNode, Node); + + private: + // friend class + friend class GenericOpMap; + friend class OpRegistry; + // Program internal unique index of operator. + // Used to help index the program. + uint32_t index_{0}; +}; + +/*! + * \brief Operator reference class. + */ +class Op : public relay::Expr { + public: + /*! \brief default constructor */ + Op() {} + /*! \brief constructor from node pointer */ + explicit Op(std::shared_ptr n) : Expr(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const OpNode* operator->() const; + /*! + * \brief Get additional registered attribute about operators. + * If nothing has been registered, an empty OpMap will be returned. + * \param attr_name The name of the attribute. + * \return An OpMap of specified attr_name. + * \tparam ValueType The type of the attribute. + */ + template + inline static const OpMap& GetAttr(const std::string& attr_name); + /*! + * \brief Get an Op for a given operator name. + * Will raise an error if the op has not been registered. + * \param op_name Name of the operator. + * \return Pointer to a Op, valid throughout program lifetime. + */ + TVM_DLL static const Op& Get(const std::string& op_name); + + /*! \brief specify container node */ + using ContainerType = OpNode; + + private: + /*! + * \brief Get generic attrmap given attr name + * \param key The attribute key + * \return reference to GenericOpMap + */ + TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key); +}; + +/*! \brief Helper structure to register operators */ +class OpRegistry { + public: + /*! \return the operator */ + const Op& op() const { + return op_; + } + /*! + * \brief setter function during registration + * Set the description of operator + * \param descr the description string. + * \return reference to self. + */ + inline OpRegistry& describe(const std::string& descr); // NOLINT(*) + /*! + * \brief Add argument information to the function. + * \param name Name of the argument. + * \param type Type of the argument. + * \param description Description of the argument. + * \return reference to self. + */ + inline OpRegistry& add_argument(const std::string &name, + const std::string &type, + const std::string &description); + /*! + * \brief Set the type key of attributes. + * \param type_key The type of of the attrs field.x + * \return reference to self. + */ + inline OpRegistry& set_attrs_type_key(const std::string& type_key); + /*! + * \brief Set the num_inputs + * \param n The number of inputs to be set. + * \return reference to self. + */ + inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*) + /*! + * \brief Set the support level of op. + * \param level The support level. + * \return reference to self. + */ + inline OpRegistry& set_support_level(int32_t level); // NOLINT(*) + /*! + * \brief Register additional attributes to operator. + * \param attr_name The name of the attribute. + * \param value The value to be set. + * \param plevel The priority level of this set, + * an higher priority level attribute + * will replace lower priority level attribute. + * Must be bigger than 0. + * + * Cannot set with same plevel twice in the code. + * + * \tparam ValueType The type of the value to be set. + */ + template + inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*) + const ValueType& value, + int plevel = 10); + + // set the name of the op to be the same as registry + inline OpRegistry& set_name() { // NOLINT(*) + get()->name = name; + return *this; + } + + private: + friend class ::dmlc::Registry; + // the name + std::string name; + /*! \brief The operator */ + Op op_; + // private constructor + OpRegistry(); + // return internal pointer to op. + inline OpNode* get(); + // update the attribute OpMap + TVM_DLL void UpdateAttr(const std::string& key, + TVMRetValue value, + int plevel); +}; + +/*! + * \brief Generic map to store additional information of Op. + */ +class GenericOpMap { + public: + /*! + * \brief Check if the map has op as key. + * \param op The key to the map + * \return 1 if op is contained in map, 0 otherwise. + */ + inline int count(const Op& op) const; + /*! + * \brief get the corresponding value element at op + * \param op The key to the map + * \return the const reference to the content value. + */ + inline const TVMRetValue& operator[](const Op& op) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param op The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + * \tparam ValueType The content value type. + */ + template + inline ValueType get(const Op& op, ValueType def_value) const; + + private: + friend class OpRegistry; + // the attribute field. + std::string attr_name_; + // internal data + std::vector > data_; + // The value + GenericOpMap() = default; +}; + +/*! + * \brief Map used to store meta-information about Op. + * \tparam ValueType The type of the value stored in map. + */ +template +class OpMap { + public: + /*! + * \brief Check if the map has op as key. + * \param op The key to the map + * \return 1 if op is contained in map, 0 otherwise. + */ + inline int count(const Op& op) const; + /*! + * \brief get the corresponding value element at op + * \param op The key to the map + * \return the const reference to the content value. + */ + inline ValueType operator[](const Op& op) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param op The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + */ + inline ValueType get(const Op& op, ValueType def_value) const; + + private: + friend class Op; + // constructor + explicit OpMap(const GenericOpMap& map) + : map_(map) {} + /*! \brief The internal map field */ + const GenericOpMap& map_; +}; + + +// internal macros to make +#define RELAY_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry & __make_ ## RelayOp + +/*! + * \def RELAY_REGISTER_OP + * \brief Register a new operator, or set attribute of the corresponding op. + * + * \param OpName The name of registry + * + * \code + * + * RELAY_REGISTER_OP("add") + * .describe("add two inputs together") + * .set_num_inputs(2) + * .set_attr("gpu_kernel", AddKernel); + * + * \endcode + */ +#define RELAY_REGISTER_OP(OpName) \ + DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \ + ::dmlc::Registry<::tvm::relay::OpRegistry>::Get()->__REGISTER_OR_GET__(OpName).set_name() + +// implementations +inline const OpNode* Op::operator->() const { + return static_cast(node_.get()); +} + +template +inline const OpMap& Op::GetAttr(const std::string& key) { + return OpMap(Op::GetGenericAttr(key)); +} + +inline OpNode* OpRegistry::get() { + return const_cast(op_.operator->()); +} + +inline OpRegistry& OpRegistry::describe(const std::string& descr) { // NOLINT(*) + get()->description = descr; + return *this; +} + +inline OpRegistry& OpRegistry::add_argument(const std::string &name, + const std::string &type, + const std::string &description) { + std::shared_ptr n = std::make_shared(); + n->name = name; + n->type_info = type; + n->description = description; + get()->arguments.push_back(AttrFieldInfo(n)); + return *this; +} + +inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) + get()->num_inputs = n; + return *this; +} + +inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) + get()->support_level = n; + return *this; +} + +template +inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) + const std::string& attr_name, + const ValueType& value, + int plevel) { + CHECK_GT(plevel, 0) + << "plevel in set_attr must be greater than 0"; + TVMRetValue rv; + rv = value; + UpdateAttr(attr_name, rv, plevel); + return *this; +} + +// member functions of OpMap +inline int GenericOpMap::count(const Op& op) const { + if (op.defined()) { + const uint32_t idx = op->index_; + return idx < data_.size() ? (data_[idx].second != 0) : 0; + } else { + return 0; + } +} + +inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const { + CHECK(op.defined()); + const uint32_t idx = op->index_; + CHECK(idx < data_.size() && data_[idx].second != 0) + << "Attribute " << attr_name_ + << " has not been registered for Operator " << op->name; + return data_[idx].first; +} + +template +inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { + CHECK(op.defined()); + const uint32_t idx = op->index_; + if (idx < data_.size() && data_[idx].second != 0) { + return data_[idx].first; + } else { + return value; + } +} + +template +inline int OpMap::count(const Op& op) const { + return map_.count(op); +} + +template +inline ValueType OpMap::operator[](const Op& op) const { + return map_[op]; +} +template +inline ValueType OpMap::get(const Op& op, ValueType def_value) const { + return map_.get(op, def_value); +} + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_OP_H_ diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index c90875db4178..a9446ebed979 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -1,7 +1,8 @@ -"""Relay namespace.""" +"""The Relay IR namespace containing the IR definition and compiler.""" from . import base from . import type as tpe from . import make +from . import op # Type Type = tpe.Type diff --git a/python/tvm/relay/make.py b/python/tvm/relay/make.py index bf9ec0e48f64..8cad18c11c46 100644 --- a/python/tvm/relay/make.py +++ b/python/tvm/relay/make.py @@ -1,6 +1,12 @@ from . import _make from . import ir +This module includes MyPy type signatures for all of the +exposed modules. +""" +from __future__ import absolute_import as _abs +from .._ffi.function import _init_api + # Base Constructors Span = _make.Span diff --git a/python/tvm/relay/op.py b/python/tvm/relay/op.py new file mode 100644 index 000000000000..373d04f984a3 --- /dev/null +++ b/python/tvm/relay/op.py @@ -0,0 +1,37 @@ +"""Relay operators""" +from __future__ import absolute_import as _abs + +import sys +from .._ffi.function import _init_api +from .._ffi.node import convert_to_node +from . import make as _make +from ..make import node as _make_node + +def _create_op(op_name): + op = _GetOp(op_name) + attrs_type_key = op.attrs_type_key + attrs_type_key = attrs_type_key if attrs_type_key else "DictAttrs" + # TODO(tqchen): improve the code build to fix the restriction. + # + # current restriction: + # - pass in args as positional arguments + # - pass in kwargs as keyword argument + def _op_func(*args, **kwargs): + args = convert_to_node(args) + # Need work to make sure constructor matches + return _make.Call(op, args, + attrs = _make.node(attrs_type_key, **kwargs)) + _op_func.__name__ = op.name + return _op_func + + +def _init_ops(): + """Helper function to initialize the operators + """ + module = sys.modules[__name__] + for name in _ListOpNames(): + f = _create_op(name.value) + setattr(module, f.__name__, f) + +_init_api("relay.op", __name__) +_init_ops() diff --git a/src/relay/op.cc b/src/relay/op.cc index 07ad5f0ae4ed..5a4241a182b1 100644 --- a/src/relay/op.cc +++ b/src/relay/op.cc @@ -1,31 +1,96 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file op.cc - * \brief Relay's representation of operators. - */ -#include "tvm/relay/op.h" -#include "tvm/ir_functor.h" +#include +#include +#include + +namespace dmlc { +// enable registry +DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry); +} // namespace dmlc namespace tvm { namespace relay { -using tvm::IRPrinter; -using namespace runtime; +// single manager of operator information. +struct OpManager { + // mutex to avoid registration from multiple threads. + std::mutex mutex; + // global operator counter + std::atomic op_counter{0}; + // storage of additional attribute table. + std::unordered_map > attr; + // get singleton of the + static OpManager* Global() { + static OpManager inst; + return &inst; + } +}; + +// find operator by name +const Op& Op::Get(const std::string& name) { + const OpRegistry* reg = dmlc::Registry::Find(name); + CHECK(reg != nullptr) + << "Operator " << name << " is not registered"; + return reg->op(); +} + +OpRegistry::OpRegistry() { + OpManager* mgr = OpManager::Global(); + std::shared_ptr n = std::make_shared(); + n->index_ = mgr->op_counter++; + op_ = Op(n); +} + +// Get attribute map by key +const GenericOpMap& Op::GetGenericAttr(const std::string& key) { + OpManager* mgr = OpManager::Global(); + std::lock_guard lock(mgr->mutex); + auto it = mgr->attr.find(key); + if (it == mgr->attr.end()) { + LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered"; + } + return *it->second.get(); +} -Operator OperatorNode::make(Type op_type) { - std::shared_ptr n = std::make_shared(); - n->op_type = std::move(op_type); - return Operator(n); +void OpRegistry::UpdateAttr( + const std::string& key, TVMRetValue value, int plevel) { + OpManager* mgr = OpManager::Global(); + std::lock_guard lock(mgr->mutex); + std::unique_ptr& op_map = mgr->attr[key]; + if (op_map == nullptr) { + op_map.reset(new GenericOpMap()); + } + uint32_t index = op_->index_; + if (op_map->data_.size() <= index) { + op_map->data_.resize(index + 1, + std::make_pair(TVMRetValue(), 0)); + } + std::pair & p = op_map->data_[index]; + CHECK(p.second != plevel) + << "Attribute " << key + << " of operator " << this->name + << " is already registered with same plevel=" << plevel; + if (p.second < plevel) { + op_map->data_[index] = std::make_pair(value, plevel); + } } -TVM_REGISTER_API("relay._make.Operator").set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = OperatorNode::make(args[0]); -}); +// Frontend APIs +using runtime::TypedPackedFunc; + +TVM_REGISTER_API("relay.op._ListOpNames") +.set_body(TypedPackedFunc()>([]() { + Array ret; + for (const std::string& name : + dmlc::Registry::ListAllNames()) { + ret.push_back(tvm::Expr(name)); + } + return ret; + })); + +TVM_REGISTER_API("relay.op._GetOp") +.set_body(TypedPackedFunc(Op::Get)); + -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const OperatorNode *node, tvm::IRPrinter *p) { - p->stream << "OperatorNode(" << node->op_type << ")"; - }); } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc new file mode 100644 index 000000000000..8b759bfbc07c --- /dev/null +++ b/src/relay/op/tensor/elemwise.cc @@ -0,0 +1,23 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file elemwise.cc + * \brief Elementwise operators. + */ +#include + +namespace tvm { +namespace relay { + +RELAY_REGISTER_OP("log") +.describe(R"code(Returns the log input array, computed element-wise. + +.. math:: + log(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor."); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_relay_op.py b/tests/python/relay/test_relay_op.py new file mode 100644 index 000000000000..93316da8ec41 --- /dev/null +++ b/tests/python/relay/test_relay_op.py @@ -0,0 +1,8 @@ +from tvm import relay + +def test_op_level1(): + assert relay.op.log + + +if __name__ == "__main__": + test_op_level1() From 0e6cdcc101b857a2159e12e6f3399fac31092db8 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 17:04:04 -0700 Subject: [PATCH 31/88] WIP --- include/tvm/relay/compiler/environment.h | 24 ++++------ include/tvm/relay/compiler/intern_table.h | 55 ----------------------- include/tvm/relay/type.h | 2 +- 3 files changed, 10 insertions(+), 71 deletions(-) delete mode 100644 include/tvm/relay/compiler/intern_table.h diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/compiler/environment.h index 3e108cd8b390..536302c31dc6 100644 --- a/include/tvm/relay/compiler/environment.h +++ b/include/tvm/relay/compiler/environment.h @@ -41,22 +41,18 @@ class EnvironmentNode : public RelayNode { private: /*! A map from string names to GlobalIds, ensures global uniqueness. */ InternTable global_map_; - /*! A map from string names to Operators, ensures global uniqueness. */ - InternTable operators; + // /*! \brief A map from file names to source fragments. */ // SourceMap source_map_ // /*! \brief A list of the errors reported during the current run. */ // std::vector errors_; public: - // This map contains all items *except* operators. - std::unordered_map items; + /*! \brief A map from ids to all global functions. */ + tvm::Map items; // Options options; - tvm::PackedFunc jit_for(Operator op); - tvm::PackedFunc reverse(Operator op); - EnvironmentNode() {} void VisitAttrs(tvm::AttrVisitor* v) final {} @@ -75,16 +71,14 @@ class EnvironmentNode : public RelayNode { // GlobalId global_id(const std::string& str); Operator op(const std::string& str); - // We can lookup a GlobalId, OperatorId. - // Defn lookup(const GlobalId& id); - // FileId add_source(std::string file_name, std::string source); + /*! \brief Lookup a global function by its name. */ + Function lookup(const GlobalVar& id); - // tvm::Array get_operators(); - // tvm::Array get_defns(); + /*! \brief Add a source fragment to the environment. */ + // FileId add_source(std::string file_name, std::string source); - // void report_error(std::string msg, Span sp); - // void display_errors(); - // void register_shape_ext(ShapeExtension ext); + void report_error(std::string msg, Span sp); + void display_errors(); static constexpr const char* _type_key = "relay.Environment"; TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); diff --git a/include/tvm/relay/compiler/intern_table.h b/include/tvm/relay/compiler/intern_table.h deleted file mode 100644 index 1850e513e5e5..000000000000 --- a/include/tvm/relay/compiler/intern_table.h +++ /dev/null @@ -1,55 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file tvm/relay/compiler/intern_table.h - * \brief A table which maps string keys to data. - * - * These are useful for mapping user-readable names - * to globally unique allocations which use pointer - * equality for comparsion. - */ -#ifndef TVM_RELAY_COMPILER_INTERN_TABLE_H_ -#define TVM_RELAY_COMPILER_INTERN_TABLE_H_ - -#include -#include -#include "dmlc/logging.h" - -namespace tvm { -namespace relay { - -struct KeyNotFound : dmlc::Error { - explicit KeyNotFound(std::string msg) : dmlc::Error(msg) {} -}; - -template -class InternTable { -private: - /*! \brief The internal table mapping from strings to T. */ - std::unordered_map table_; - - public: - /*! \brief Insert a new key into the table. - * \note Attempting to reinsert a key triggers an error. - */ - void Insert(const std::string& key, const T& value) { - if (table_.find(key) == table_.end()) { - table_.insert({key, value}); - } else { - throw dmlc::Error( - std::string("you have previously interred a value for: ") + key); - } - } - - /*! \brief Lookup the data in the table. */ - const T& Lookup(std::string key) { - if (table_.find(key) != table_.end()) { - return table_.at(key); - } else { - throw KeyNotFound(std::string("could not find match") + key); - } - } -}; - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_COMPILER_INTERN_TABLE_H_ diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 07b047471aba..ef8c4c71f5b7 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -92,7 +92,7 @@ class TensorTypeNode : public BaseTensorTypeNode { /*! \brief Construct a floating-point type */ TVM_DLL static TensorType Float(int bits, int lanes = 1); - /*1 \brief Construct a boolean type */ + /*! \brief Construct a boolean type */ TVM_DLL static TensorType Bool(int lanes = 1); static constexpr const char* _type_key = "relay.TensorType"; From 8520a30c9b84430cdb1a44c7ad8dfe8ea7d0a886 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 23 Aug 2018 13:04:01 -0700 Subject: [PATCH 32/88] WIP --- cmake/config.cmake | 3 + include/tvm/relay/compiler/environment.h | 20 +++--- include/tvm/relay/compiler/type_infer.h | 2 +- include/tvm/relay/expr_functor.h | 4 +- include/tvm/relay/expr_visitor.h | 6 +- src/relay/compiler/environment.cc | 88 +++++++----------------- 6 files changed, 41 insertions(+), 82 deletions(-) diff --git a/cmake/config.cmake b/cmake/config.cmake index c364a88cce11..e09fdb241bf1 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -19,6 +19,9 @@ # $ make -j8 #-------------------------------------------------------------------- +SET(CMAKE_C_COMPLIER clang) +SET(CMAKE_CXX_COMPILER clang++) + #--------------------------------------------- # Backend runtimes. #--------------------------------------------- diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/compiler/environment.h index 536302c31dc6..5b33e781b399 100644 --- a/include/tvm/relay/compiler/environment.h +++ b/include/tvm/relay/compiler/environment.h @@ -8,7 +8,6 @@ #include #include -#include "tvm/relay/compiler/intern_table.h" #include "../expr.h" #include "../type.h" #include "../op.h" @@ -40,7 +39,7 @@ struct Environment; class EnvironmentNode : public RelayNode { private: /*! A map from string names to GlobalIds, ensures global uniqueness. */ - InternTable global_map_; + tvm::Map global_map_; // /*! \brief A map from file names to source fragments. */ // SourceMap source_map_ @@ -61,18 +60,17 @@ class EnvironmentNode : public RelayNode { std::unordered_map global_funcs); /*! Add an operator to the Enviroment. */ - void register_op(const Operator& op); - // void add(const Operator& op, bool update = false); + void register_op(const Op& op); + void add(const GlobalVar& var, const Function & func, bool update = false); + void try_add(const GlobalVar& var, const Function & func, bool update=false); + void update(const GlobalVar& var, const Function & func); + void remove(const GlobalVar& var); - // void try_add(const Item& item, bool update=false); - // void update(const Item& item); - // void remove(const GlobalId& id); - - // GlobalId global_id(const std::string& str); - Operator op(const std::string& str); + GlobalVar GetGlobalVar(const std::string& str); /*! \brief Lookup a global function by its name. */ - Function lookup(const GlobalVar& id); + Function Lookup(const GlobalVar& id); + Function Lookup(const std::string & s); /*! \brief Add a source fragment to the environment. */ // FileId add_source(std::string file_name, std::string source); diff --git a/include/tvm/relay/compiler/type_infer.h b/include/tvm/relay/compiler/type_infer.h index 6d07de1c29e8..c084fb7a109e 100644 --- a/include/tvm/relay/compiler/type_infer.h +++ b/include/tvm/relay/compiler/type_infer.h @@ -24,7 +24,7 @@ Expr Infer(const Environment & env, const Expr & e); /*! \brief Ensures that an operator is well-formed with respect * to Relay's type system. */ -Operator CheckOperator(const Environment & env, const Operator & op); +Op CheckOp(const Environment & env, const Op & op); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 922892e8a7a5..2067b90bd364 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -113,7 +113,7 @@ class ExprFunctor { virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const OperatorNode* op, + virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { throw dmlc::Error(std::string("Do not have a default for ") + op->type_key()); @@ -133,7 +133,7 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(CallNode); RELAY_EXPR_FUNCTOR_DISPATCH(LetNode); RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); - RELAY_EXPR_FUNCTOR_DISPATCH(OperatorNode); + RELAY_EXPR_FUNCTOR_DISPATCH(OpNode); return vtable; } }; diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 2039414b4238..d1e8a99dc374 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -58,7 +58,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctorVisitExpr(op->false_value, args...); } - void VisitExpr_(const OperatorNode* op, Args... args) override { return; } + void VisitExpr_(const OpNode* op, Args... args) override { return; } }; template @@ -72,8 +72,8 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor(op); } - Expr VisitExpr_(const OperatorNode* op, Args... args) override { - return GetRef(op); + Expr VisitExpr_(const OpNode* op, Args... args) override { + return GetRef(op); } Expr VisitExpr_(const TupleNode* op, Args... args) override { diff --git a/src/relay/compiler/environment.cc b/src/relay/compiler/environment.cc index 735ef79ceb3a..7ce0785f4f8f 100644 --- a/src/relay/compiler/environment.cc +++ b/src/relay/compiler/environment.cc @@ -1,7 +1,7 @@ /*! * Copyright (c) 2018 by Contributors * \file environment.cc - * \brief Relay global environment. + * \brief The global environment in Relay. */ #include #include "tvm/relay/compiler/environment.h" @@ -24,34 +24,17 @@ Environment EnvironmentNode::make( return Environment(n); } -void EnvironmentNode::register_op(const Operator& op) { - this->operators.Insert(op->name, op); +GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { + auto global_id = global_map_.find(str); + if (global_id != global_map_.end()) { + return (*global_id).second; + } else { + auto id = GlobalVarNode::make(str); + this->global_map_.Set(str, id); + return id; + } } -// tvm::PackedFunc EnvironmentNode::jit_for(OperatorId id) { -// return this->lookup(id)->compiler; -// } - -// GlobalId EnvironmentNode::global_id(const std::string &str) { -// try { -// return global_map_.Lookup(str); -// } catch (const KeyNotFound &err) { -// GlobalId id = GlobalIdNode::make(str); -// global_map_.Insert(str, id); -// return id; -// } -// } - -// OperatorId EnvironmentNode::operator_id(const std::string &str) { -// try { -// return operator_map_.Lookup(str); -// } catch (const KeyNotFound &err) { -// OperatorId id = OperatorIdNode::make(str); -// operator_map_.Insert(str, id); -// return id; -// } -// } - // // Add a new item to the global environment // // throws an exception if the item already // // exists. @@ -79,15 +62,15 @@ void EnvironmentNode::register_op(const Operator& op) { // } else { // operators.insert({op->id, op}); // } -// } else if (const DefnNode *d = item.as()) { -// auto def = GetRef(d); +// } else if (const FunctionNode *d = item.as()) { +// auto def = GetRef(d); // auto type = def->type; // if (items.find(def->id) != items.end()) { // if (!update) { // throw dmlc::Error("already have definition for XXXX."); // } -// auto old_type = items[def->id].as()->type; +// auto old_type = items[def->id].as()->type; // if (!alpha_eq(type, old_type)) { // throw dmlc::Error( @@ -107,23 +90,18 @@ void EnvironmentNode::register_op(const Operator& op) { // void EnvironmentNode::remove(const GlobalId &id) { this->items.erase(id); } -// Defn EnvironmentNode::lookup(const GlobalId &id) { -// if (items.find(id) != items.end()) { -// return items.at(id); -// } else { -// throw EnvError(std::string("there is no definition of ") + id->name); -// } -// } - -Operator EnvironmentNode::op(const std::string & op_name) { - // FIX ME - return operators.Lookup(op_name); +Function EnvironmentNode::Lookup(const GlobalVar &var) { + if (items.find(var) != items.end()) { + return items.at(var); + } else { + throw Error(std::string("there is no definition of ") + var->name_hint); + } } -// Defn EnvironmentNode::lookup_global(const std::string &str) { -// GlobalId id = this->global_id(str); -// return this->lookup(id); -// } +Function EnvironmentNode::Lookup(const std::string &str) { + GlobalVar id = this->GetGlobalVar(str); + return this->Lookup(id); +} // inline FileId EnvironmentNode::add_source(std::string file_name, // std::string source) { @@ -163,26 +141,6 @@ Operator EnvironmentNode::op(const std::string & op_name) { // } // } -// Array EnvironmentNode::get_operators() { -// std::vector ops; -// for (auto pair : this->operators) { -// ops.push_back(pair.second); -// } -// return Array(ops); -// } - -// Array EnvironmentNode::get_defns() { -// std::vector defns; -// for (auto pair : this->items) { -// defns.push_back(pair.second); -// } -// return Array(defns); -// } - -// void EnvironmentNode::register_shape_ext(ShapeExtension ext) { -// this->shape_exts_.Insert(ext->name, ext); -// } - TVM_REGISTER_API("relay._make.Environment") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = EnvironmentNode::make({}); From cedf57244359cb017a075413ed08a6be5ad7dab4 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 23 Aug 2018 17:14:05 -0700 Subject: [PATCH 33/88] Change over to new Node construction --- python/tvm/relay/__init__.py | 16 ++++- python/tvm/relay/_make.pyi | 91 -------------------------- python/tvm/relay/base.py | 4 ++ python/tvm/relay/expr.py | 29 ++++++++ python/tvm/relay/ir_builder.py | 4 ++ python/tvm/relay/make.py | 75 --------------------- python/tvm/relay/op.py | 2 +- python/tvm/relay/type.py | 59 ++++++++++++++++- tests/python/relay/test_ir_nodes.py | 59 +++++++++-------- tests/python/relay/test_typechecker.py | 4 +- 10 files changed, 143 insertions(+), 200 deletions(-) delete mode 100644 python/tvm/relay/_make.pyi delete mode 100644 python/tvm/relay/make.py diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index a9446ebed979..037d71854689 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -1,9 +1,12 @@ """The Relay IR namespace containing the IR definition and compiler.""" from . import base from . import type as tpe -from . import make +from . import expr from . import op +# Span +Span = base.Span + # Type Type = tpe.Type TensorType = tpe.TensorType @@ -11,3 +14,14 @@ TypeParam = tpe.TypeParam TypeConstraint = tpe.TypeConstraint FuncType = tpe.FuncType + +# Expr +Constant = expr.Constant +Tuple = expr.Tuple +LocalVar = expr.LocalVar +GlobalVar = expr.GlobalVar +Param = expr.Param +Function = expr.Function +Call = expr.Call +Let = expr.Let +If = expr.If diff --git a/python/tvm/relay/_make.pyi b/python/tvm/relay/_make.pyi deleted file mode 100644 index d94857916319..000000000000 --- a/python/tvm/relay/_make.pyi +++ /dev/null @@ -1,91 +0,0 @@ -# from typing import Dict, List, Any, Callable, TypeVar as PyTypeVar -# import nnvm.relay.ir as ir -# import nnvm.relay.env as env -# import ctypes - -# # Environment -# def Environment(items: Dict[ir.GlobalId, ir.Item]) -> env.Environment: ... - -# # Items TODO(@jroesch) Correct Anys to the right type. -# def Operator(id: ir.OperatorId, tvm_name: str, ty: ir.Type, compiler: Any, fwd_mode: Any, rev_mode: Any) -> ir.Operator: ... -# def Defn(id: ir.GlobalId, ty: ir.Type, body: ir.Function) -> ir.Defn: ... - -# # Types -# def IntType(bits: int, lanes: int) -> ir.Type: ... -# def UIntType(bits: int, lanes: int) -> ir.Type: ... -# def FloatType(bits: int, lanes: int) -> ir.Type: ... -# def BoolType(lanes: int) -> ir.Type: ... -# def TupleType(fields: List[ir.Type]) -> ir.Type: ... -# def TensorType(dtype: ir.Type, shape: ir.Type) -> ir.Type: ... -# def TypeParam(name: str, kind: ir.Kind) -> ir.Type: ... -# def TypeQuantifier(id: ir.TypeId, body: ir.Type) -> ir.Type: ... -# def TypeArrow(left: ir.Type, right: ir.Type) -> ir.Type: ... -# def TypeVar(kind: ir.Kind) -> ir.Type: ... -# def PlaceholderType() -> ir.Type: ... -# def ShapeSeq(shapes: List[ir.Type]) -> ir.ShapeSeq: ... -# def ShapeSingleton(value: int) -> ir.ShapeSingleton: ... -# def ShapeAttr(id: ir.StringLit) -> ir.ShapeAttr: ... -# def ShapeProjection(shape: ir.Type, value: int) -> ir.ShapeProjection: ... -# def ShapeBinaryOp(op: ir.ShapeOp, left: ir.Type, right: ir.Type) -> ir.ShapeBinaryOp: ... -# def ShapeBroadcast(left: ir.Type, right: ir.Type) -> ir.ShapeBroadcast: ... -# def ShapeExtension(name: str, eval: Any) -> ir.ShapeExtension: ... -# def TypeCall(func: ir.Type, args: List[ir.Type]) -> ir.TypeCall: ... -# def RefType(data_type: ir.Type) -> ir.RefType: ... - -# # Expressions -# def Param(id: ir.LocalId, type: ir.Type) -> ir.Param: ... -# def Function(ty_params: List[ir.TypeId], params: List[ir.Param], ret_type: ir.Type, body: ir.Expr) -> ir.Function: ... -# def LocalId(name: str) -> ir.Expr: ... -# def GlobalId(name: str) -> ir.Expr: ... -# def OperatorId(name: str) -> ir.Expr: ... -# def Let(id: ir.LocalId, ty: ir.Type, value: ir.Expr, body: ir.Expr) -> ir.Expr: ... -# def IntLit(value: int) -> ir.IntLit: ... -# def FloatLit(value: float) -> ir.FloatLit: ... -# def TensorLit(value: List[ir.Expr]) -> ir.TensorLit: ... -# def Tuple(fields: List[ir.Expr]) -> ir.Expr: ... -# def BoolLit(value: bool) -> ir.BoolLit: ... -# def StringLit(value: str) -> ir.StringLit: ... -# def Attributes(attrs: Dict[str, ir.Expr]) -> ir.Attributes: ... -# def Call(func: ir.Expr, args: List[ir.Expr], attrs: ir.Attributes) -> ir.Call: ... -# def UnaryOp(op: ir.UOp, arg: ir.Expr) -> ir.Expr: ... -# def BinaryOp(op: ir.BOp, left: ir.Expr, right: ir.Expr) -> ir.Expr: ... -# def Projection(tuple: ir.Expr, field : int) -> ir.Expr: ... -# def Gradient(node: ir.Expr) -> ir.Expr: ... -# def Cast(target: ir.Type, node: ir.Expr) -> ir.Expr: ... -# def Debug(node: ir.Expr) -> ir.Expr: ... -# def Zero(type: ir.Type) -> ir.Expr: ... -# def If(guard: ir.Expr, true_branch: ir.Expr, false_branch: ir.Expr) -> ir.Expr: ... -# def Ref(value: ir.Expr) -> ir.Expr: ... -# def ReadRef(ref: ir.Expr) -> ir.Expr: ... -# def WriteRef(ref: ir.Expr, value: ir.Expr) -> ir.Expr: ... - -# # Values -# def IntValue(value: int) -> ir.TensorValue: ... -# def FloatValue(value: float) -> ir.TensorValue: ... -# def BoolValue(value: bool) -> ir.TensorValue: ... -# def TensorValue(handle: ctypes.c_void_p) -> ir.TensorValue: ... -# def Closure(env: Dict[ir.LocalId, ir.Value], fn: ir.Function) -> ir.Closure: ... - -# # Error Reporting -# def Span(file_id: ir.FileId, lineno: int, col_offset: int) -> ir.NodeBase: ... -# def FileId(file_id: int) -> ir.FileId: ... - -# # Utils -# def _alpha_eq(e1: ir.Expr, e2: ir.Expr) -> bool: ... -# def _type_alpha_eq(e1: ir.Type, e2: ir.Type) -> bool: ... -# def _expr_set_span(e: ir.Expr, sp: ir.Span) -> None: ... -# def _type_set_span(t: ir.Type, sp: ir.Span) -> None: ... -# def _item_set_span(t: ir.Item, sp: ir.Span) -> None: ... -# def Node_hash(n: ir.Node) -> int: ... -# def Operator_is_generic(op: ir.Operator) -> bool: ... - -# # FIXME -# def UnionFind() -> Any: ... -# def TypeUnifier() -> Any: ... - -# T = PyTypeVar('T') -# U = PyTypeVar('U') -# PassFunc = Callable[[env.Environment], Callable[[T], U]] - -# # Passes -# def ItemPass(name: str, pass_func: PassFunc[ir.Item, ir.Item]) -> ir.ItemPass: ... diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index 687ba53ac005..ee818617f629 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -3,6 +3,7 @@ from __future__ import absolute_import as _abs from typing import Union from .._ffi.node import NodeBase, register_node as _register_tvm_node +from . import _make NodeBase = NodeBase @@ -25,3 +26,6 @@ class Span(NodeBase): source: "FileSource" lineno: int col_offset: int + + def __init__(self, source, lineno, col_offset): + self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index c17a69dd0dc9..7f5dcbd0beb5 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -7,6 +7,7 @@ from .type import Type, TypeParam from tvm import expr from ._type_infer import _get_checked_type +from . import _make class Expr(NodeBase): """The base type for all Relay exprressions.""" @@ -19,6 +20,9 @@ class Constant(Expr): """ data: tvm.nd.NDArray + def __init__(self, data: tvm.nd.NDArray) -> None: + self.__init_handle_by_constructor__(_make.Constant, data) + @register_relay_node class Tuple(Expr): """A hetereogenous sequence of values. @@ -26,16 +30,26 @@ class Tuple(Expr): """ fields: List[Expr] + def __init__(self, fields: List[Expr]) -> None: + self.__init_handle_by_constructor__(_make.Tuple, fields) + + @register_relay_node class LocalVar(Expr): """A local variable in Relay.""" name_hint: str + def __init__(self, name_hint: str) -> None: + self.__init_handle_by_constructor__(_make.LocalVar, name_hint) + @register_relay_node class GlobalVar(Expr): """A global variable in Relay.""" name_hint: str + def __init__(self, name_hint: str) -> None: + self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) + @register_relay_node class Param(Expr): """A function type in Relay, see tvm/relay/type.h for more details. @@ -43,6 +57,10 @@ class Param(Expr): var: LocalVar type: Type + def __init__(self, var: LocalVar, type: Type) -> None: + self.__init_handle_by_constructor__(_make.Param, var, type) + + @register_relay_node class Function(Expr): type_params: List[TypeParam] @@ -50,11 +68,17 @@ class Function(Expr): ret_type: Type body: Expr + def __init__(self, params: List[Param], ret_type: Type, body: Expr, type_params: List[TypeParam]=[]) -> None: + self.__init_handle_by_constructor__(_make.Function, params, ret_type, body, type_params) + class Call(Expr): op: Expr args: List[Expr] # todo(@jroesch): add attrs + def __init__(self, op: Expr, args: List[Expr], attrs, ty_args) -> None: + self.__init_handle_by_constructor__(_make.Call, op, args, attrs, ty_args) + @register_relay_node class Let(Expr): var: LocalVar @@ -62,6 +86,9 @@ class Let(Expr): body: Expr value_type: Type # should be type nanotation + def __init__(self, var: LocalVar, value: Expr, body: Expr, value_type: Type) -> None: + self.__init_handle_by_constructor__(_make.Let, var, value, body, value_type) + @register_relay_node class If(Expr): cond: Expr @@ -69,3 +96,5 @@ class If(Expr): false_value: Expr span: Span + def __init__(self, cond: Expr, true_value: Expr, false_value: Expr) -> None: + self.__init_handle_by_constructor__(_make.If, cond, true_value, false_value) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 35bc31265987..f24e7baa1483 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -4,6 +4,7 @@ from . import type as ty from . import expr from . import make as mk +from . import op class ExprBuilder(): def __init__(self, expr): @@ -142,6 +143,9 @@ def get(self): return _mk_let(bindings, self.ret_value) +def op(name): + return op._create_op(name) + def bool_dtype(): return 'uint1' diff --git a/python/tvm/relay/make.py b/python/tvm/relay/make.py deleted file mode 100644 index 8cad18c11c46..000000000000 --- a/python/tvm/relay/make.py +++ /dev/null @@ -1,75 +0,0 @@ -from . import _make -from . import ir - -This module includes MyPy type signatures for all of the -exposed modules. -""" -from __future__ import absolute_import as _abs -from .._ffi.function import _init_api - -# Base Constructors -Span = _make.Span - -# Environment -Environment = _make.Environment - -# Type Constructors -TensorType = _make.TensorType -TypeParam = _make.TypeParam -FuncType = _make.FuncType - -# Types -def IntType(bits: int, lanes: int=1) -> ir.Type: - """Constructs a integer base type. - - :param bits: The bit width of the integer type. - :param lanes: The number of vector elements for this datatype. - - """ - return _make.IntType(bits, lanes) - - -def UIntType(bits: int, lanes: int=1) -> ir.Type: - """Constructs a unsigned integer base type. - - :param bits: The bit width of the unsigned type. - :param lanes: The number of vector elements for this datatype. - """ - return _make.UIntType(bits, lanes) - - -def FloatType(bits: int, lanes: int=1) -> ir.Type: - """Constructs a floating point base type. - - :param bits: The bit width of the unsigned type. - :param lanes: The number of vector elements for this datatype. - """ - return _make.FloatType(bits, lanes) - - -def BoolType(lanes: int =1) -> ir.Type: - """Constructs a boolean base type. - - :param bits: The bit width of the unsigned type. - :param lanes: The number of vector elements for this datatype. - """ - return _make.BoolType(lanes) - -# Expr Constructors -Constant = _make.Constant -Tuple = _make.Tuple -LocalVar = _make.LocalVar -GlobalVar = _make.GlobalVar -Param = _make.Param -Function = _make.Function -Call = _make.Call -Let = _make.Let -If = _make.If -IncompleteType = _make.IncompleteType - -# Unifier -UnionFind = _make.UnionFind -TypeUnifier = _make.TypeUnifier - -# Utility Functionality @TODO(jroesch): move to another location -_type_alpha_eq = _make._type_alpha_eq diff --git a/python/tvm/relay/op.py b/python/tvm/relay/op.py index 373d04f984a3..dae498b66c12 100644 --- a/python/tvm/relay/op.py +++ b/python/tvm/relay/op.py @@ -4,7 +4,7 @@ import sys from .._ffi.function import _init_api from .._ffi.node import convert_to_node -from . import make as _make +from . import _make from ..make import node as _make_node def _create_op(op_name): diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index a04089792282..c7b8964c20e8 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -5,7 +5,7 @@ from .base import Span, NodeBase, register_relay_node from tvm import expr # TODO(@jroesch): move me -from ._make import _type_alpha_eq +from . import _make class Type(NodeBase): """The base type for all Relay types.""" @@ -14,7 +14,7 @@ def __eq__(self, other) -> bool: """Compares two Relay types for structural equivalence using alpha equivalence. """ - return bool(_type_alpha_eq(self, other)) + return bool(_make._type_alpha_eq(self, other)) def __ne__(self, other) -> bool: return not self.__eq__(other) @@ -31,6 +31,9 @@ class TensorType(Type): shape: List[expr.Expr] span: Span + def __init__(self, dtype: str, shape: List[expr.Expr]) -> None: + self.__init_handle_by_constructor__(_make.TensorType,dtype, shape) + class Kind(IntEnum): """The kind of a type parameter, represents a variable shape, base type, type, or dimension. @@ -49,6 +52,9 @@ class TypeParam(Type): kind: Kind span: Span + def __init__(self, var: expr.Var, kind: Kind) -> None: + self.__init_handle_by_constructor__(_make.TypeParam, var, kind) + @register_relay_node class TypeConstraint(Type): """Abstract class representing a type constraint.""" @@ -64,7 +70,54 @@ class FuncType(Type): ret_type: Type span: Span + def __init__(self, arg_types: List[Type], ret_type: Type, type_params: List[TypeParam], type_constraints: List[TypeConstraint]) -> None: + self.__init_handle_by_constructor__(_make.FuncType, arg_types, ret_type, type_params, type_constraints) + +@register_relay_node +class TypeCall(Type): + def __init__() -> None: + pass + + @register_relay_node class IncompleteType(Type): """An incomplete type.""" - pass + + def __init__(self, kind: Kind) -> None: + self.__init_handle_by_constructor__(_make.IncompleteType, kind) + +def IntType(bits: int, lanes: int=1) -> Type: + """Constructs a integer base type. + + :param bits: The bit width of the integer type. + :param lanes: The number of vector elements for this datatype. + + """ + return _make.IntType(bits, lanes) + + +def UIntType(bits: int, lanes: int=1) -> Type: + """Constructs a unsigned integer base type. + + :param bits: The bit width of the unsigned type. + :param lanes: The number of vector elements for this datatype. + """ + return _make.UIntType(bits, lanes) + + +def FloatType(bits: int, lanes: int=1) -> Type: + """Constructs a floating point base type. + + :param bits: The bit width of the unsigned type. + :param lanes: The number of vector elements for this datatype. + """ + return _make.FloatType(bits, lanes) + + +def BoolType(lanes: int =1) -> Type: + """Constructs a boolean base type. + + :param bits: The bit width of the unsigned type. + :param lanes: The number of vector elements for this datatype. + """ + return _make.BoolType(lanes) diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 26fe06109513..676aa347950b 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -1,14 +1,11 @@ """ test ir""" import tvm from tvm import relay -import tvm.relay.make as mk -from tvm import expr +from tvm.expr import * # Span - - def test_span() -> None: - span = mk.Span(None, 1, 1) + span = relay.Span(None, 1, 1) assert span.source == None assert span.lineno == 1 assert span.col_offset == 1 @@ -19,11 +16,10 @@ def test_span() -> None: # Types - def test_tensor_type() -> None: shape = tvm.convert([1, 2, 3]) dtype = 'float32' - tt = mk.TensorType(shape, dtype) + tt = relay.TensorType(shape, dtype) assert tt.dtype == dtype assert tt.shape == shape assert tt.span == None @@ -31,7 +27,7 @@ def test_tensor_type() -> None: def test_type_param() -> None: - tp = mk.TypeParam('name', relay.Kind.Shape) + tp = relay.TypeParam('name', relay.Kind.Shape) tp.kind == relay.Kind.Shape tp.span # TODO allow us to set span str(tp) @@ -42,7 +38,7 @@ def test_func_type() -> None: type_constraints = tvm.convert([]) # TODO: fill me in arg_types = tvm.convert([]) ret_type = None - tf = mk.FuncType(arg_types, ret_type, type_params, type_constraints) + tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) assert tf.type_params == type_params assert tf.type_constraints == type_constraints assert tf.arg_types == arg_types @@ -54,7 +50,7 @@ def test_func_type() -> None: def test_constant() -> None: arr = tvm.nd.array(10) - const = mk.Constant(arr) + const = relay.Constant(arr) assert const.data == arr assert const.span == None str(const) @@ -62,7 +58,7 @@ def test_constant() -> None: def test_tuple() -> None: fields = tvm.convert([]) - tup = mk.Tuple(fields) + tup = relay.Tuple(fields) assert tup.fields == fields assert tup.span == None str(tup) @@ -70,7 +66,7 @@ def test_tuple() -> None: def test_local_var() -> None: name_hint = 's' - lv = mk.LocalVar(name_hint) + lv = relay.LocalVar(name_hint) lv.name_hint == name_hint # assert lv.span == None todo(@jroesch): what do we do about spans str(lv) @@ -78,16 +74,16 @@ def test_local_var() -> None: def test_global_var() -> None: name_hint = 'g' - gv = mk.GlobalVar(name_hint) + gv = relay.GlobalVar(name_hint) gv.name_hint == name_hint # assert lv.span == None todo(@jroesch): what do we do about spans str(gv) def test_param() -> None: - lv = mk.LocalVar('x') + lv = relay.LocalVar('x') ty = None - param = mk.Param(lv, ty) + param = relay.Param(lv, ty) assert param.var == lv assert param.type == ty assert param.span == None @@ -96,11 +92,11 @@ def test_param() -> None: def test_function() -> None: param_names = ['a', 'b', 'c', 'd'] - params = tvm.convert([mk.Param(mk.LocalVar(n), None) for n in param_names]) + params = tvm.convert([relay.Param(relay.LocalVar(n), None) for n in param_names]) ret_type = None body = None type_params = tvm.convert([]) - fn = mk.Function(params, ret_type, body, type_params) + fn = relay.Function(params, ret_type, body, type_params) assert fn.params == params assert fn.body == body assert fn.type_params == type_params @@ -109,10 +105,10 @@ def test_function() -> None: def test_call() -> None: - op = mk.LocalVar('f') + op = relay.LocalVar('f') arg_names = ['a', 'b', 'c', 'd'] - args = tvm.convert([mk.LocalVar(n) for n in arg_names]) - call = mk.Call(op, args, None, None) + args = tvm.convert([relay.LocalVar(n) for n in arg_names]) + call = relay.Call(op, args, None, None) assert call.op == op assert call.args == args assert call.span == None @@ -120,13 +116,13 @@ def test_call() -> None: def test_let() -> None: - lv = mk.LocalVar('x') + lv = relay.LocalVar('x') ty = None arr = tvm.nd.array(10) - value = mk.Constant(arr) + value = relay.Constant(arr) # I would prefer that the order of arguments # matches syntax let x : t = v in b - let = mk.Let(lv, value, lv, ty) + let = relay.Let(lv, value, lv, ty) assert let.var == lv assert let.value == value assert let.value_type == ty @@ -136,10 +132,10 @@ def test_let() -> None: def test_if() -> None: - cond = mk.LocalVar('cond') - left = mk.LocalVar('left') - right = mk.LocalVar('right') - ife = mk.If(cond, left, right) + cond = relay.LocalVar('cond') + left = relay.LocalVar('left') + right = relay.LocalVar('right') + ife = relay.If(cond, left, right) assert ife.cond == cond assert ife.true_value == left assert ife.false_value == right @@ -152,3 +148,12 @@ def test_if() -> None: test_tensor_type() test_type_param() test_func_type() + test_constant() + test_tuple() + test_local_var() + test_global_var() + test_param() + test_function() + test_call() + test_let() + test_if() diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py index c6bc1a05ebe9..bf172eb07935 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_typechecker.py @@ -3,7 +3,7 @@ """ import tvm.relay.make as mk from tvm.relay.type_infer import check_expr -from tvm.relay.ir_builder import IRBuilder, float_type +from tvm.relay.ir_builder import IRBuilder, float_type, op def has_type(expr, typ): env = mk.Environment({}) @@ -23,7 +23,7 @@ def test_monomorphic_let(): def test_single_op(): "Program: fn (x : int32) { let t1 = f(x); t1 }" b = IRBuilder() - f = b.op('f') + f = op('log') with b.function(('x', float_type())) as func: x, = func.param_ids() t1 = b.let('t1', f(x)) From c32bbe7583b9ed2058a12b7803247b0546a5f47b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 23 Aug 2018 17:18:37 -0700 Subject: [PATCH 34/88] Basic tests working --- python/tvm/relay/ir_builder.py | 5 +- tests/python/relay/test_alpha_eq.py | 1 - tests/python/relay/test_typechecker.py | 1 - tests/python/relay/test_unifier.py | 163 ++++++++++++------------- 4 files changed, 83 insertions(+), 87 deletions(-) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index f24e7baa1483..af83c9948be2 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -3,8 +3,7 @@ import tvm from . import type as ty from . import expr -from . import make as mk -from . import op +from . import op as _op class ExprBuilder(): def __init__(self, expr): @@ -144,7 +143,7 @@ def get(self): return _mk_let(bindings, self.ret_value) def op(name): - return op._create_op(name) + return _op._create_op(name) def bool_dtype(): return 'uint1' diff --git a/tests/python/relay/test_alpha_eq.py b/tests/python/relay/test_alpha_eq.py index e4fbbcca93ce..6c0e7779eae6 100644 --- a/tests/python/relay/test_alpha_eq.py +++ b/tests/python/relay/test_alpha_eq.py @@ -1,5 +1,4 @@ """Test alpha-equivalence of expressions and types.""" -from tvm.relay import make as mk # from relay.ir import alpha_eq, ShapeOp, Kind # from relay.typing import TYPE_DEFAULTS # from relay import ir diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py index bf172eb07935..6a16aadcb002 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_typechecker.py @@ -1,7 +1,6 @@ """Test that type checker correcly computes types for expressions. """ -import tvm.relay.make as mk from tvm.relay.type_infer import check_expr from tvm.relay.ir_builder import IRBuilder, float_type, op diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py index 21889faa51ee..c45e6ac4f732 100644 --- a/tests/python/relay/test_unifier.py +++ b/tests/python/relay/test_unifier.py @@ -3,17 +3,16 @@ between incomplete types. """ import tvm -from tvm.relay import ir +from tvm import relay from tvm.relay.unifier import UnionFind, TypeUnifier from tvm.relay.ir_builder import bool_type, uint_type, int_type, float_type, func_type from tvm.relay import ir_builder as build -import tvm.relay.make as mk def test_insert_and_find(): - uf = mk.UnionFind() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) + uf = relay.UnionFind() + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) uf.insert(v1) uf.insert(v2) assert uf.find(v1) == v1 @@ -21,9 +20,9 @@ def test_insert_and_find(): def test_insert_error(): - uf = mk.UnionFind() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) + uf = relay.UnionFind() + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) uf.insert(v1) try: uf.find(v2) @@ -33,10 +32,10 @@ def test_insert_error(): def test_unify(): - uf = mk.UnionFind() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) - v3 = mk.IncompleteType(ir.Kind.Type) + uf = relay.UnionFind() + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) + v3 = relay.IncompleteType(ir.Kind.Type) uf.insert(v1) uf.insert(v2) uf.insert(v3) @@ -56,8 +55,8 @@ def test_unify(): def test_unify_multiple_levels(): - uf = mk.UnionFind() - v = [mk.IncompleteType(ir.Kind.Type) for _ in range(9)] + uf = relay.UnionFind() + v = [relay.IncompleteType(ir.Kind.Type) for _ in range(9)] for var in v: uf.insert(var) uf.unify(v[0], v[1]) @@ -94,7 +93,7 @@ def test_unify_multiple_levels(): def unify_types(t1, t2): - unifier = mk.TypeUnifier() + unifier = relay.TypeUnifier() return unifier.unify(t1, t2) # TODO(sslyu, weberlo, joshpoll): put in isinstance asserts once those work @@ -136,8 +135,8 @@ def test_unify_concrete_func_type(): def test_unify_func_type_with_holes(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.unify(v1, bool_type()) arr1 = func_type([int_type()], bool_type()) @@ -145,7 +144,7 @@ def test_unify_func_type_with_holes(): unified = unifier.unify(arr1, arr2) assert unified == arr1 - v2 = mk.IncompleteType(ir.Kind.BaseType) + v2 = relay.IncompleteType(ir.Kind.BaseType) unifier.insert(v2) unifier.unify(v2, int_type()) arr3 = func_type([v2], bool_type()) @@ -179,10 +178,10 @@ def test_reject_incompatible_func_types(): def test_unify_typevars_with_each_other(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) - v3 = mk.IncompleteType(ir.Kind.Type) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) + v3 = relay.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unifier.insert(v3) @@ -194,10 +193,10 @@ def test_unify_typevars_with_each_other(): def test_unify_typevars_with_basetype(): - unifier = mk.TypeUnifier() + unifier = relay.TypeUnifier() bt = bool_type() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unified1 = unifier.unify(v1, bt) @@ -207,10 +206,10 @@ def test_unify_typevars_with_basetype(): def test_unify_compatible_typevars(): - unifier = mk.TypeUnifier() + unifier = relay.TypeUnifier() bt = bool_type() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unifier.unify(v1, bt) @@ -221,9 +220,9 @@ def test_unify_compatible_typevars(): assert unified == bt # def test_unify_incompatible_typevars(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.Type) -# v2 = mk.IncompleteType(ir.Kind.Type) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.Type) +# v2 = relay.IncompleteType(ir.Kind.Type) # bt = bool_type() # tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) # unifier.insert(v1) @@ -238,16 +237,16 @@ def test_unify_compatible_typevars(): # return # def test_unify_typevar_with_quantifier(): -# unifier = mk.TypeUnifier() +# unifier = relay.TypeUnifier() # tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bool_type()) -# v1 = mk.IncompleteType(ir.Kind.BaseType) +# v1 = relay.IncompleteType(ir.Kind.BaseType) # unifier.insert(v1) # unified = unifier.unify(v1, tq) # assert unified == tq # def test_unify_typevars_inside_concrete_quantifier(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.BaseType) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.BaseType) # unifier.insert(v1) # tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v1) # tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), bool_type()) @@ -258,8 +257,8 @@ def test_unify_compatible_typevars(): def test_unify_concrete_tensors(): bt = build.bool_dtype() shape = tvm.convert([1, 2, 3]) - tt1 = mk.TensorType(shape, bt) - tt2 = mk.TensorType(shape, bt) + tt1 = relay.TensorType(shape, bt) + tt2 = relay.TensorType(shape, bt) unified = unify_types(tt1, tt2) assert unified == tt1 @@ -268,8 +267,8 @@ def test_unify_tensor_shape_reject(): bt = build.bool_dtype() shape1 = tvm.convert([1, 2, 3]) shape2 = tvm.convert([2, 3, 4]) - tt1 = mk.TensorType(shape1, bt) - tt2 = mk.TensorType(shape2, bt) + tt1 = relay.TensorType(shape1, bt) + tt2 = relay.TensorType(shape2, bt) try: unify_types(tt1, tt2) assert False @@ -281,8 +280,8 @@ def test_unify_tensor_dtype_reject(): bt1 = build.bool_dtype() bt2 = build.int_dtype() shape = tvm.convert([1, 2, 3]) - tt1 = mk.TensorType(shape, bt1) - tt2 = mk.TensorType(shape, bt2) + tt1 = relay.TensorType(shape, bt1) + tt2 = relay.TensorType(shape, bt2) try: unify_types(tt1, tt2) assert False @@ -292,15 +291,15 @@ def test_unify_tensor_dtype_reject(): # def test_unify_quantified_tensors(): # x = TypeParam("x", ir.type.Kind.Shape) # y = TypeParam("y", ir.type.Kind.Shape) -# tq1 = TypeQuantifier(x, mk.TensorType(bool_type(), x)) -# tq2 = TypeQuantifier(y, mk.TensorType(bool_type(), y)) +# tq1 = TypeQuantifier(x, relay.TensorType(bool_type(), x)) +# tq2 = TypeQuantifier(y, relay.TensorType(bool_type(), y)) # unified = unify_types(tq1, tq2) # assert unified == tq1 # a = TypeParam("a", ir.type.Kind.BaseType) # b = TypeParam("b", ir.type.Kind.BaseType) -# tq3 = TypeQuantifier(a, mk.TensorType(a, make_shape([1, 2, 3]))) -# tq4 = TypeQuantifier(b, mk.TensorType(b, make_shape([1, 2, 3]))) +# tq3 = TypeQuantifier(a, relay.TensorType(a, make_shape([1, 2, 3]))) +# tq4 = TypeQuantifier(b, relay.TensorType(b, make_shape([1, 2, 3]))) # unified = unify_types(tq3, tq4) # assert unified == tq3 @@ -335,8 +334,8 @@ def test_unify_tensor_dtype_reject(): # return # def test_unify_products_typevar(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.BaseType) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.BaseType) # bt = bool_type() # pt1 = TupleType([bt, bt]) # pt2 = TupleType([v1, bt]) @@ -354,14 +353,14 @@ def test_unify_tensor_dtype_reject(): def test_subst_basetype(): - unifier = mk.TypeUnifier() + unifier = relay.TypeUnifier() bt = bool_type() assert bt == unifier.subst(bt) def test_subst_simple_hole(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.BaseType) bt = bool_type() unifier.insert(v1) unifier.unify(v1, bt) @@ -369,9 +368,9 @@ def test_subst_simple_hole(): def test_subst_typevar_for_typevar(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) @@ -380,9 +379,9 @@ def test_subst_typevar_for_typevar(): def test_subst_typevar_for_typevar_comm(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) @@ -391,15 +390,15 @@ def test_subst_typevar_for_typevar_comm(): def test_subst_concrete_arrow(): - unifier = mk.TypeUnifier() + unifier = relay.TypeUnifier() arr1 = func_type([int_type()], int_type()) assert unifier.subst(arr1) == arr1 def test_subst_arrow_with_holes(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) - v2 = mk.IncompleteType(ir.Kind.BaseType) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.BaseType) + v2 = relay.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.insert(v2) unifier.unify(v1, int_type()) @@ -409,17 +408,17 @@ def test_subst_arrow_with_holes(): assert unifier.subst(arr1) == arr2 # def test_subst_concrete_quantifier(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.BaseType) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.BaseType) # tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) # unifier.insert(v1) # unifier.unify(v1, tq) # assert unifier.subst(v1) == tq # def test_subst_quantifier_with_holes(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.Type) -# v2 = mk.IncompleteType(ir.Kind.Type) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.Type) +# v2 = relay.IncompleteType(ir.Kind.Type) # tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v2) # intty = int_type() # tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), intty) @@ -431,16 +430,16 @@ def test_subst_arrow_with_holes(): def test_subst_concrete_tensor(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.Type) unifier.insert(v1) - tt = mk.TensorType(tvm.convert([1, 2, 3]), 'uint1') + tt = relay.TensorType(tvm.convert([1, 2, 3]), 'uint1') unifier.unify(v1, tt) assert unifier.subst(v1) == tt # def test_subst_concrete_product(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.Type) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.Type) # unifier.insert(v1) # bt = bool_type() # pt = TupleType([bt, bt]) @@ -448,16 +447,16 @@ def test_subst_concrete_tensor(): # assert unifier.subst(v1) == pt # def test_subst_product_with_holes(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.Type) -# v2 = mk.IncompleteType(ir.Kind.Type) -# v3 = mk.IncompleteType(ir.Kind.Type) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.Type) +# v2 = relay.IncompleteType(ir.Kind.Type) +# v3 = relay.IncompleteType(ir.Kind.Type) # unifier.insert(v1) # unifier.insert(v2) # unifier.insert(v3) -# tt1 = mk.TensorType(int_type(), tvm.convert([])) -# tt2 = mk.TensorType(FloatType(32), tvm.convert([])) +# tt1 = relay.TensorType(int_type(), tvm.convert([])) +# tt2 = relay.TensorType(FloatType(32), tvm.convert([])) # pt1 = TupleType([tt1, v2, v3]) # unifier.unify(v2, tt2) # unifier.unify(v3, v2) @@ -466,13 +465,13 @@ def test_subst_concrete_tensor(): # assert unifier.subst(v1) == pt2 # def test_subst_concrete_ref(): -# unifier = mk.TypeUnifier() +# unifier = relay.TypeUnifier() # rt = RefType(bool_type()) # assert unifier.subst(rt) == rt # def test_subst_ref_with_hole(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.Type) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.Type) # unifier.insert(v1) # unifier.unify(v1, bool_type()) @@ -481,9 +480,9 @@ def test_subst_concrete_tensor(): # assert unifier.subst(rt1) == rt2 # def test_typevar_on_lhs(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.BaseType) -# v2 = mk.IncompleteType(ir.Kind.Type) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.BaseType) +# v2 = relay.IncompleteType(ir.Kind.Type) # bt = bool_type() # tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt, bt) # unifier.insert(v1) From 5d0242fc01471e2701a54bd8483c2df10bd9d787 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 23 Aug 2018 17:28:36 -0700 Subject: [PATCH 35/88] Remove unifier from Python interface --- python/tvm/relay/_unifier.py | 5 - python/tvm/relay/_unifier.pyi | 12 - python/tvm/relay/unifier.py | 61 ---- src/relay/compiler/unifier.cc | 80 ----- tests/python/relay/test_unifier.py | 495 ----------------------------- 5 files changed, 653 deletions(-) delete mode 100644 python/tvm/relay/_unifier.py delete mode 100644 python/tvm/relay/_unifier.pyi delete mode 100644 python/tvm/relay/unifier.py delete mode 100644 tests/python/relay/test_unifier.py diff --git a/python/tvm/relay/_unifier.py b/python/tvm/relay/_unifier.py deleted file mode 100644 index 41f5fe374b3e..000000000000 --- a/python/tvm/relay/_unifier.py +++ /dev/null @@ -1,5 +0,0 @@ -"""FFI functions for the Unifier.""" - -from tvm._ffi.function import _init_api - -_init_api("relay._unifier", __name__) diff --git a/python/tvm/relay/_unifier.pyi b/python/tvm/relay/_unifier.pyi deleted file mode 100644 index 6ecd309250a6..000000000000 --- a/python/tvm/relay/_unifier.pyi +++ /dev/null @@ -1,12 +0,0 @@ -from tvm.relay.ir import NodeBase - -class UnionFind(NodeBase): ... -class TypeUnifier(NodeBase): ... - -def UnionFind_insert(self: UnionFind, var: ir.IncompleteType) -> None: ... -def UnionFind_unify(self: UnionFind, var1: ir.IncompleteType, var2: ir.IncompleteType) -> None: ... -def UnionFind_find(self: UnionFind, var: ir.IncompleteType) -> ir.Type: ... - -def TypeUnifier_insert(self: TypeUnifier, var: ir.IncompleteType) -> None: ... -def TypeUnifier_unify(self, type1: ir.Type, type2: ir.Type) -> ir.Type: ... -def TypeUnifier_subst(self, type1: ir.Type) -> ir.Type: ... diff --git a/python/tvm/relay/unifier.py b/python/tvm/relay/unifier.py deleted file mode 100644 index cb818de19c1d..000000000000 --- a/python/tvm/relay/unifier.py +++ /dev/null @@ -1,61 +0,0 @@ -"""The Python interface to Relay's UnionFind and TypeUnifier.""" - -from typing import Dict -from .ir import register_relay_node, NodeBase -from . import ir -from . import _unifier - -@register_relay_node -class UnionFind(NodeBase): - """Python API for UnionFind. - - The UnionFind maintains equality classes of type variables, the - representative of an equality class may be a type (which can) - contain type variables. The TypeUnifier uses this to build a - unification procedure between types. - """ - uf_map: Dict[ir.IncompleteType, ir.IncompleteType] - - def insert(self, var: ir.IncompleteType) -> None: - """Insert a type variable into the union find. - - :param: var: The variable to be inserted. - """ - return _unifier.UnionFind_insert(self, var) - - def unify(self, var: ir.IncompleteType, typ: ir.Type) -> None: - """Unify a type variable with an arbitrary type. - - :param: var: A type variable to be unified. - :param: typ: The type to be unified with. - """ - return _unifier.UnionFind_unify(self, var, typ) - - def find(self, var: ir.IncompleteType) -> ir.IncompleteType: - """Find the representative element of the type var. - - :param: var: The variable to lookup in the union find. - """ - return _unifier.UnionFind_find(self, var) - -@register_relay_node -class TypeUnifier(NodeBase): - """Python API for the TypeUnifier.""" - #pylint: disable=invalid-name - uf: UnionFind - eq_map: Dict[ir.TypeParam, ir.TypeParam] - - def insert(self, var: ir.IncompleteType) -> None: - return _unifier.TypeUnifier_insert(self, var) - - def unify(self, type1: ir.Type, type2: ir.Type) -> ir.Type: - """Unify two types producing the unified type as a result. - - :param: type1: The first type to be unified. - :param: type2: The second type to be unified. - :returns: The unified type. - """ - return _unifier.TypeUnifier_unify(self, type1, type2) - - def subst(self, type1: ir.Type) -> ir.Type: - return _unifier.TypeUnifier_subst(self, type1) diff --git a/src/relay/compiler/unifier.cc b/src/relay/compiler/unifier.cc index ff46e8e863d1..5c0fbcf3ec71 100644 --- a/src/relay/compiler/unifier.cc +++ b/src/relay/compiler/unifier.cc @@ -369,85 +369,5 @@ Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { } - -TVM_REGISTER_API("relay._make.TypeUnifier") - .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args.size() < 3) { - *ret = TypeUnifierNode::make(UnionFindNode::make({})); - } else { - *ret = TypeUnifierNode::make(args[0]); - } - }); - -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const TypeUnifierNode *node, - tvm::IRPrinter *p) { - p->stream << "TypeUnifierNode(" << node->uf << ")"; - }); - -TVM_REGISTER_API("relay._unifier.UnionFind_insert") - .set_body([](TVMArgs args, TVMRetValue *ret) { - try { - UnionFind uf = args[0]; - uf->insert(args[1]); - } catch (std::exception &e) { - throw UnionFindError(e.what()); - } - }); - -TVM_REGISTER_API("relay._unifier.UnionFind_unify") - .set_body([](TVMArgs args, TVMRetValue *ret) { - try { - UnionFind uf = args[0]; - uf->unify(args[1], args[2]); - } catch (std::exception &e) { - throw UnionFindError(e.what()); - } - }); - -TVM_REGISTER_API("relay._unifier.UnionFind_find") - .set_body([](TVMArgs args, TVMRetValue *ret) { - try { - UnionFind uf = args[0]; - *ret = uf->find(args[1]); - } catch (std::exception &e) { - throw UnionFindError(e.what()); - } - }); - -TVM_REGISTER_API("relay._unifier.TypeUnifier_insert") - .set_body([](TVMArgs args, TVMRetValue *ret) { - try { - TypeUnifier unifier = args[0]; - IncompleteType var = args[1]; - unifier->insert(var); - } catch (std::exception &e) { - throw UnificationError(e.what()); - } - }); - -TVM_REGISTER_API("relay._unifier.TypeUnifier_unify") - .set_body([](TVMArgs args, TVMRetValue *ret) { - try { - TypeUnifier unifier = args[0]; - Type t1 = args[1]; - Type t2 = args[2]; - *ret = unifier->unify(t1, t2); - } catch (std::exception &e) { - throw UnificationError(e.what()); - } - }); - -TVM_REGISTER_API("relay._unifier.TypeUnifier_subst") - .set_body([](TVMArgs args, TVMRetValue *ret) { - try { - TypeUnifier unifier = args[0]; - Type t = args[1]; - *ret = unifier->subst(t); - } catch (std::exception &e) { - throw SubstitutionError(e.what()); - } - }); - } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py deleted file mode 100644 index c45e6ac4f732..000000000000 --- a/tests/python/relay/test_unifier.py +++ /dev/null @@ -1,495 +0,0 @@ -""" -Test the type unifier, which solves systems of equations -between incomplete types. -""" -import tvm -from tvm import relay -from tvm.relay.unifier import UnionFind, TypeUnifier -from tvm.relay.ir_builder import bool_type, uint_type, int_type, float_type, func_type -from tvm.relay import ir_builder as build - - -def test_insert_and_find(): - uf = relay.UnionFind() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - uf.insert(v1) - uf.insert(v2) - assert uf.find(v1) == v1 - assert uf.find(v2) == v2 - - -def test_insert_error(): - uf = relay.UnionFind() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - uf.insert(v1) - try: - uf.find(v2) - assert False - except: - return - - -def test_unify(): - uf = relay.UnionFind() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - v3 = relay.IncompleteType(ir.Kind.Type) - uf.insert(v1) - uf.insert(v2) - uf.insert(v3) - uf.unify(v1, v2) - rep = uf.find(v1) - assert (rep == v1 or rep == v2) - assert uf.find(v1) == rep - assert uf.find(v2) == rep - assert uf.find(v3) == v3 - assert v3 != rep - uf.unify(v1, v3) - new_rep = uf.find(v3) - assert (rep == v1 or rep == v2 or rep == v3) - assert uf.find(v1) == new_rep - assert uf.find(v2) == new_rep - assert uf.find(v3) == new_rep - - -def test_unify_multiple_levels(): - uf = relay.UnionFind() - v = [relay.IncompleteType(ir.Kind.Type) for _ in range(9)] - for var in v: - uf.insert(var) - uf.unify(v[0], v[1]) - uf.unify(v[0], v[2]) - uf.unify(v[3], v[4]) - uf.unify(v[4], v[5]) - uf.unify(v[6], v[7]) - uf.unify(v[6], v[8]) - rep1 = uf.find(v[0]) - rep2 = uf.find(v[3]) - rep3 = uf.find(v[6]) - assert (rep1 == v[0] or rep1 == v[1] or rep1 == v[2]) - assert (rep2 == v[3] or rep2 == v[4] or rep2 == v[5]) - assert (rep3 == v[6] or rep3 == v[7] or rep3 == v[8]) - for i in range(3): - assert uf.find(v[i]) == rep1 - assert uf.find(v[i + 3]) == rep2 - assert uf.find(v[i + 6]) == rep3 - # now unify two of the groups - uf.unify(v[1], v[4]) - new_rep1 = uf.find(v[0]) - new_rep2 = uf.find(v[6]) - assert (new_rep1 == v[0] or new_rep1 == v[1] or new_rep1 == v[2] - or new_rep1 == v[3] or new_rep1 == v[4] or new_rep1 == v[5]) - assert (new_rep2 == v[6] or new_rep2 == v[7] or new_rep2 == v[8]) - for i in range(6): - assert uf.find(v[i]) == new_rep1 - for i in range(3): - assert uf.find(v[i + 6]) == new_rep2 - -# We have checked that the basic machinery in the UnionFind works -# and now we will test the type unifier which will fill in holes -# between type equalities by the process of unification. - - -def unify_types(t1, t2): - unifier = relay.TypeUnifier() - return unifier.unify(t1, t2) - -# TODO(sslyu, weberlo, joshpoll): put in isinstance asserts once those work - - -def test_unify_int(): - intty = int_type(1) - unified = unify_types(intty, intty) - assert intty == unified - - -def test_unify_bool(): - boolty = bool_type() - unified = unify_types(boolty, boolty) - assert boolty == unified - - -def test_unify_float(): - floatty = float_type(4) - unified = unify_types(floatty, floatty) - assert floatty == unified - - -def test_unify_incompatible_basetypes(): - bt = bool_type() - intty = int_type(32) - try: - unify_types(bt, intty) - assert False - except: - return - - -def test_unify_concrete_func_type(): - arr1 = func_type([int_type()], int_type()) - arr2 = func_type([int_type()], int_type()) - unified = unify_types(arr1, arr2) - assert unified == arr1 - - -def test_unify_func_type_with_holes(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.BaseType) - unifier.insert(v1) - unifier.unify(v1, bool_type()) - arr1 = func_type([int_type()], bool_type()) - arr2 = func_type([int_type()], v1) - unified = unifier.unify(arr1, arr2) - assert unified == arr1 - - v2 = relay.IncompleteType(ir.Kind.BaseType) - unifier.insert(v2) - unifier.unify(v2, int_type()) - arr3 = func_type([v2], bool_type()) - unified = unifier.unify(arr1, arr3) - assert unified == arr1 - - -def test_reject_incompatible_func_types(): - arr1 = func_type([int_type()], bool_type()) - arr2 = func_type([int_type(), bool_type()], bool_type()) - try: - unify_types(arr1, arr2) - assert False - except: - return - -# def test_unify_concrete_type_quantifiers(): -# tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) -# tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), int_type()) -# unified = unify_types(tq1, tq2) -# assert unified == tq1 - -# def test_unify_basetype_with_quantifier_error(): -# bt = bool_type() -# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) -# try: -# unify_types(bt, tq) -# assert False -# except: -# return - - -def test_unify_typevars_with_each_other(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - v3 = relay.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - unifier.insert(v2) - unifier.insert(v3) - unified = unifier.unify(v1, v2) - assert (unified == v1 or unified == v2) - assert unified != v3 - new_unified = unifier.unify(v1, v3) - assert (new_unified == v1 or new_unified == v2 or new_unified == v3) - - -def test_unify_typevars_with_basetype(): - unifier = relay.TypeUnifier() - bt = bool_type() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - unifier.insert(v2) - unified1 = unifier.unify(v1, bt) - assert unified1 == bt - unified2 = unifier.unify(v1, v2) - assert unified2 == bt - - -def test_unify_compatible_typevars(): - unifier = relay.TypeUnifier() - bt = bool_type() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - unifier.insert(v2) - unifier.unify(v1, bt) - unifier.unify(v2, bt) - # because types to which v1 and v2 have been assigned are compatible, - # this should proceed without problems - unified = unifier.unify(v1, v2) - assert unified == bt - -# def test_unify_incompatible_typevars(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.Type) -# v2 = relay.IncompleteType(ir.Kind.Type) -# bt = bool_type() -# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) -# unifier.insert(v1) -# unifier.insert(v2) -# unifier.unify(v1, bt) -# unifier.unify(v2, tq) -# # bt cannot be unified with tq, so unifying v1 and v2 should give an error -# try: -# unifier.unify(v1, v2) -# assert False -# except: -# return - -# def test_unify_typevar_with_quantifier(): -# unifier = relay.TypeUnifier() -# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bool_type()) -# v1 = relay.IncompleteType(ir.Kind.BaseType) -# unifier.insert(v1) -# unified = unifier.unify(v1, tq) -# assert unified == tq - -# def test_unify_typevars_inside_concrete_quantifier(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.BaseType) -# unifier.insert(v1) -# tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v1) -# tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), bool_type()) -# unified = unifier.unify(tq1, tq2) -# assert unified == tq2 - - -def test_unify_concrete_tensors(): - bt = build.bool_dtype() - shape = tvm.convert([1, 2, 3]) - tt1 = relay.TensorType(shape, bt) - tt2 = relay.TensorType(shape, bt) - unified = unify_types(tt1, tt2) - assert unified == tt1 - - -def test_unify_tensor_shape_reject(): - bt = build.bool_dtype() - shape1 = tvm.convert([1, 2, 3]) - shape2 = tvm.convert([2, 3, 4]) - tt1 = relay.TensorType(shape1, bt) - tt2 = relay.TensorType(shape2, bt) - try: - unify_types(tt1, tt2) - assert False - except: - return - - -def test_unify_tensor_dtype_reject(): - bt1 = build.bool_dtype() - bt2 = build.int_dtype() - shape = tvm.convert([1, 2, 3]) - tt1 = relay.TensorType(shape, bt1) - tt2 = relay.TensorType(shape, bt2) - try: - unify_types(tt1, tt2) - assert False - except: - return - -# def test_unify_quantified_tensors(): -# x = TypeParam("x", ir.type.Kind.Shape) -# y = TypeParam("y", ir.type.Kind.Shape) -# tq1 = TypeQuantifier(x, relay.TensorType(bool_type(), x)) -# tq2 = TypeQuantifier(y, relay.TensorType(bool_type(), y)) -# unified = unify_types(tq1, tq2) -# assert unified == tq1 - -# a = TypeParam("a", ir.type.Kind.BaseType) -# b = TypeParam("b", ir.type.Kind.BaseType) -# tq3 = TypeQuantifier(a, relay.TensorType(a, make_shape([1, 2, 3]))) -# tq4 = TypeQuantifier(b, relay.TensorType(b, make_shape([1, 2, 3]))) -# unified = unify_types(tq3, tq4) -# assert unified == tq3 - -# def test_unify_concrete_products(): -# bt = bool_type() -# intty = int_type() -# pt1 = TupleType([bt, intty]) -# pt2 = TupleType([bt, intty]) -# unified = unify_types(pt1, pt2) -# assert unified == pt1 - -# def test_unify_products_reject_size(): -# bt = bool_type() -# intty = IntType(32) -# pt1 = TupleType([bt, bt, intty]) -# pt2 = TupleType([bt, intty]) -# try: -# unify_types(pt1, pt2) -# assert False -# except: -# return - -# def test_unify_products_reject_member(): -# bt = bool_type() -# intty = int_type() -# pt1 = TupleType([bt, bt]) -# pt2 = TupleType([bt, intty]) -# try: -# unify_types(pt1, pt2) -# assert False -# except: -# return - -# def test_unify_products_typevar(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.BaseType) -# bt = bool_type() -# pt1 = TupleType([bt, bt]) -# pt2 = TupleType([v1, bt]) -# unifier.insert(v1) -# unified = unifier.unify(pt1, pt2) -# assert unified == pt1 - -# def test_unify_quantified_products(): -# x = TypeParam("x", ir.Kind.Type) -# y = TypeParam("y", ir.Kind.Type) -# p1 = TypeQuantifier(x, TupleType([int_type(), x])) -# p2 = TypeQuantifier(y, TupleType([int_type(), y])) -# unified = unify_types(p1, p2) -# assert unified == p1 - - -def test_subst_basetype(): - unifier = relay.TypeUnifier() - bt = bool_type() - assert bt == unifier.subst(bt) - - -def test_subst_simple_hole(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.BaseType) - bt = bool_type() - unifier.insert(v1) - unifier.unify(v1, bt) - assert unifier.subst(v1) == bt - - -def test_subst_typevar_for_typevar(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - unifier.insert(v2) - - unifier.unify(v1, v2) - assert unifier.subst(v1) == unifier.subst(v2) - - -def test_subst_typevar_for_typevar_comm(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - unifier.insert(v2) - - unifier.unify(v2, v1) - assert unifier.subst(v1) == unifier.subst(v2) - - -def test_subst_concrete_arrow(): - unifier = relay.TypeUnifier() - arr1 = func_type([int_type()], int_type()) - assert unifier.subst(arr1) == arr1 - - -def test_subst_arrow_with_holes(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.BaseType) - v2 = relay.IncompleteType(ir.Kind.BaseType) - unifier.insert(v1) - unifier.insert(v2) - unifier.unify(v1, int_type()) - unifier.unify(v2, bool_type()) - arr1 = func_type([v1], v2) - arr2 = func_type([int_type()], bool_type()) - assert unifier.subst(arr1) == arr2 - -# def test_subst_concrete_quantifier(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.BaseType) -# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) -# unifier.insert(v1) -# unifier.unify(v1, tq) -# assert unifier.subst(v1) == tq - -# def test_subst_quantifier_with_holes(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.Type) -# v2 = relay.IncompleteType(ir.Kind.Type) -# tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v2) -# intty = int_type() -# tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), intty) - # unifier.insert(v1) - # unifier.insert(v2) - # unifier.unify(v2, intty) - # unifier.unify(v1, tq1) - # assert unifier.subst(v1) == tq2 - - -def test_subst_concrete_tensor(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - tt = relay.TensorType(tvm.convert([1, 2, 3]), 'uint1') - unifier.unify(v1, tt) - assert unifier.subst(v1) == tt - -# def test_subst_concrete_product(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.Type) -# unifier.insert(v1) -# bt = bool_type() -# pt = TupleType([bt, bt]) -# unifier.unify(v1, pt) -# assert unifier.subst(v1) == pt - -# def test_subst_product_with_holes(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.Type) -# v2 = relay.IncompleteType(ir.Kind.Type) -# v3 = relay.IncompleteType(ir.Kind.Type) -# unifier.insert(v1) -# unifier.insert(v2) -# unifier.insert(v3) - -# tt1 = relay.TensorType(int_type(), tvm.convert([])) -# tt2 = relay.TensorType(FloatType(32), tvm.convert([])) -# pt1 = TupleType([tt1, v2, v3]) -# unifier.unify(v2, tt2) -# unifier.unify(v3, v2) -# unifier.unify(v1, pt1) -# pt2 = TupleType([tt1, tt2, tt2]) -# assert unifier.subst(v1) == pt2 - -# def test_subst_concrete_ref(): -# unifier = relay.TypeUnifier() -# rt = RefType(bool_type()) -# assert unifier.subst(rt) == rt - -# def test_subst_ref_with_hole(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.Type) -# unifier.insert(v1) - -# unifier.unify(v1, bool_type()) -# rt1 = RefType(v1) -# rt2 = RefType(bool_type()) -# assert unifier.subst(rt1) == rt2 - -# def test_typevar_on_lhs(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.BaseType) -# v2 = relay.IncompleteType(ir.Kind.Type) -# bt = bool_type() -# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt, bt) -# unifier.insert(v1) -# unifier.insert(v2) -# unified1 = unifier.unify(bt, v1) -# assert unified1 == bt -# unified2 = unifier.unify(tq, v2) -# assert unified2 == tq -# assert unifier.subst(v1) == bt -# assert unifier.subst(v2) == tq From 5b3573bb0f8f80082eaf123e5e074da516963290 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 24 Aug 2018 10:51:02 -0700 Subject: [PATCH 36/88] [OP] Structral refactor --- include/tvm/base.h | 6 --- include/tvm/relay/expr.h | 6 ++- include/tvm/relay/op.h | 1 - python/tvm/relay/__init__.py | 6 +++ python/tvm/relay/op.py | 4 +- python/tvm/relay/op/__init__.py | 6 +++ python/tvm/relay/op/_make.py | 4 ++ python/tvm/relay/op/_tensor.py | 4 ++ python/tvm/relay/op/registry.py | 1 + python/tvm/relay/op/tensor.py | 60 +++++++++++++++++++++++++++++ src/relay/compiler/unifier.cc | 2 +- src/relay/{ => ir}/base.cc | 0 src/relay/{ => ir}/expr.cc | 0 src/relay/{ => ir}/op.cc | 0 src/relay/{ => ir}/type.cc | 0 src/relay/op/tensor/elemwise.cc | 46 ++++++++++++++++++++-- tests/python/relay/test_relay_op.py | 8 +++- 17 files changed, 137 insertions(+), 17 deletions(-) create mode 100644 python/tvm/relay/op/__init__.py create mode 100644 python/tvm/relay/op/_make.py create mode 100644 python/tvm/relay/op/_tensor.py create mode 100644 python/tvm/relay/op/registry.py create mode 100644 python/tvm/relay/op/tensor.py rename src/relay/{ => ir}/base.cc (100%) rename src/relay/{ => ir}/expr.cc (100%) rename src/relay/{ => ir}/op.cc (100%) rename src/relay/{ => ir}/type.cc (100%) diff --git a/include/tvm/base.h b/include/tvm/base.h index be848b34cd43..464259bc0527 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -134,11 +134,5 @@ struct NodeFactoryReg { */ #define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) -#define TVM_REGISTER_NODE_TYPE(TypeName) \ - static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \ - ::dmlc::Registry<::tvm::NodeFactoryReg>::Get()->__REGISTER__(TypeName::_type_key) \ - .set_body([]() { return std::make_shared(); }) - - } // namespace tvm #endif // TVM_BASE_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index a29c8486ffb6..a4a683297ea5 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -282,8 +282,10 @@ class CallNode : public ExprNode { v->Visit("span", &span); } - TVM_DLL static Call make(Expr op, Array args, Attrs attrs, - Array ty_args); + TVM_DLL static Call make(Expr op, + Array args, + Attrs attrs = Attrs(), + Array ty_args = Array()); static constexpr const char* _type_key = "relay.Call"; TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode); diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index f7e1cfbbc8c2..cae3d9db6920 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -276,7 +276,6 @@ class OpMap { const GenericOpMap& map_; }; - // internal macros to make #define RELAY_REGISTER_VAR_DEF \ static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry & __make_ ## RelayOp diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 037d71854689..f94b572f6b44 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -4,6 +4,10 @@ from . import expr from . import op +# import all operators in the loop namespace +from .op import * + + # Span Span = base.Span @@ -18,6 +22,7 @@ # Expr Constant = expr.Constant Tuple = expr.Tuple +# TODO: GlobalVar, LocalVar-> var LocalVar = expr.LocalVar GlobalVar = expr.GlobalVar Param = expr.Param @@ -25,3 +30,4 @@ Call = expr.Call Let = expr.Let If = expr.If +Var = LocalVar diff --git a/python/tvm/relay/op.py b/python/tvm/relay/op.py index dae498b66c12..d54edf47c5ee 100644 --- a/python/tvm/relay/op.py +++ b/python/tvm/relay/op.py @@ -2,7 +2,7 @@ from __future__ import absolute_import as _abs import sys -from .._ffi.function import _init_api + from .._ffi.node import convert_to_node from . import _make from ..make import node as _make_node @@ -33,5 +33,5 @@ def _init_ops(): f = _create_op(name.value) setattr(module, f.__name__, f) -_init_api("relay.op", __name__) + _init_ops() diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py new file mode 100644 index 000000000000..02e49ec40ff8 --- /dev/null +++ b/python/tvm/relay/op/__init__.py @@ -0,0 +1,6 @@ +"""Relay core operators.""" +# operator defs +from .tensor import * + +# operator registry +from . import _tensor diff --git a/python/tvm/relay/op/_make.py b/python/tvm/relay/op/_make.py new file mode 100644 index 000000000000..79c86cbb0254 --- /dev/null +++ b/python/tvm/relay/op/_make.py @@ -0,0 +1,4 @@ +"""Constructor APIs""" +from ..._ffi.function import _init_api + +_init_api("relay.op._make", __name__) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py new file mode 100644 index 000000000000..08dedee0923c --- /dev/null +++ b/python/tvm/relay/op/_tensor.py @@ -0,0 +1,4 @@ +"""Backend compiler related feature regsitration""" + + + diff --git a/python/tvm/relay/op/registry.py b/python/tvm/relay/op/registry.py new file mode 100644 index 000000000000..d7426429ef6f --- /dev/null +++ b/python/tvm/relay/op/registry.py @@ -0,0 +1 @@ +"""Mechanism to work with operator registry.""" diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py new file mode 100644 index 000000000000..7155db3a4cd5 --- /dev/null +++ b/python/tvm/relay/op/tensor.py @@ -0,0 +1,60 @@ +"""Basic tensor operations.""" +from __future__ import absolute_import as _abs +from . import _make + +# We create a wrapper function for each operator in the +# python side to call into the positional _make.OpName function. +# +# We make this decision so that we can: +# - Have declare python docstring for each function +# - Enable keyword arguments easily +# - Not put too much burden on FFI to support complicated features +# like default value and keyword arguments + + +def log(data): + """Take log of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.log(data) + + +def exp(data): + """Take exp of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.exp(data) + + +def sqrt(data): + """Take sqrt of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.sqrt(data) diff --git a/src/relay/compiler/unifier.cc b/src/relay/compiler/unifier.cc index 5c0fbcf3ec71..b7cc296cc5db 100644 --- a/src/relay/compiler/unifier.cc +++ b/src/relay/compiler/unifier.cc @@ -9,7 +9,7 @@ #include "tvm/relay/compiler/alpha_eq.h" #include "./unifier.h" #include "./type_visitor.h" -#include "./type_subst.h" +//#include "./type_subst.h" // #include "tvm/relay/typeck/kindchecker.h" namespace tvm { diff --git a/src/relay/base.cc b/src/relay/ir/base.cc similarity index 100% rename from src/relay/base.cc rename to src/relay/ir/base.cc diff --git a/src/relay/expr.cc b/src/relay/ir/expr.cc similarity index 100% rename from src/relay/expr.cc rename to src/relay/ir/expr.cc diff --git a/src/relay/op.cc b/src/relay/ir/op.cc similarity index 100% rename from src/relay/op.cc rename to src/relay/ir/op.cc diff --git a/src/relay/type.cc b/src/relay/ir/type.cc similarity index 100% rename from src/relay/type.cc rename to src/relay/ir/type.cc diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index 8b759bfbc07c..79301a7fac24 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -3,21 +3,59 @@ * \file elemwise.cc * \brief Elementwise operators. */ +#include #include namespace tvm { namespace relay { -RELAY_REGISTER_OP("log") +// Quick helper macro +// - Expose a positional make function to construct the node. +// - Register op to the registry. +// +// We make the decision to always only expose positional argument. +// We will do rewrapping in the frontend to support language +// sugars such as keyword arguments and default value. +// +#define RELAY_REGISTER_UNARY_OP(OpName) \ + TVM_REGISTER_API("relay.op._make." OpName) \ + .set_body_typed([](Expr data) { \ + static const Op& op = Op::Get(OpName); \ + return CallNode::make(op, {data}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(1) \ + .add_argument("data", "Tensor", "The input tensor.") + + +RELAY_REGISTER_UNARY_OP("log") .describe(R"code(Returns the log input array, computed element-wise. .. math:: log(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor."); +.set_support_level(1); + + +RELAY_REGISTER_UNARY_OP("exp") +.describe(R"code(Returns the exp input array, computed element-wise. + +.. math:: + \exp(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1); + + +RELAY_REGISTER_UNARY_OP("sqrt") +.describe(R"code(Returns the sqrt input array, computed element-wise. + +.. math:: + sqrt(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_relay_op.py b/tests/python/relay/test_relay_op.py index 93316da8ec41..4235dd918d93 100644 --- a/tests/python/relay/test_relay_op.py +++ b/tests/python/relay/test_relay_op.py @@ -1,7 +1,13 @@ from tvm import relay def test_op_level1(): - assert relay.op.log + x = relay.Var("x") + + for op_name in ["log", "exp", "sqrt"]: + y = getattr(relay, op_name)(x) + assert y.op.name == op_name + assert y.op.support_level == 1 + assert y.args[0] == x if __name__ == "__main__": From aebf1249fc30f70900b7066f56487d8f1a8a0a7f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Aug 2018 14:34:28 -0700 Subject: [PATCH 37/88] Add type_subst back --- src/relay/compiler/type_subst.cc | 39 ++++++++++++++++++++++++++++++++ src/relay/compiler/type_subst.h | 19 ++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 src/relay/compiler/type_subst.cc create mode 100644 src/relay/compiler/type_subst.h diff --git a/src/relay/compiler/type_subst.cc b/src/relay/compiler/type_subst.cc new file mode 100644 index 000000000000..6650f59bad51 --- /dev/null +++ b/src/relay/compiler/type_subst.cc @@ -0,0 +1,39 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_subst.cc + * \brief Function for substituting a concrete type in place of a type ID + */ +#include "./type_subst.h" +#include "./type_visitor.h" + +namespace tvm { +namespace relay { + +struct TypeSubst : TypeFVisitor { + tvm::Map subst_map; + + explicit TypeSubst(tvm::Map subst_map) + : subst_map(subst_map) {} + + Type VisitType_(const TypeParamNode *op) override { + auto id = GetRef(op); + if (subst_map.find(id) != subst_map.end()) { + return this->subst_map[id]; + } else { + return id; + } + } +}; + +Type type_subst(const Type &type, const TypeParam &target, const Type &subst) { + TypeSubst ty_sub({ {target, subst} }); + return ty_sub.VisitType(type); +} + +Type type_subst(const Type &type, tvm::Map subst_map) { + TypeSubst ty_sub(subst_map); + return ty_sub.VisitType(type); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/compiler/type_subst.h b/src/relay/compiler/type_subst.h new file mode 100644 index 000000000000..0bf0de5a4b85 --- /dev/null +++ b/src/relay/compiler/type_subst.h @@ -0,0 +1,19 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file typeck/type_subst.h + * \brief Utility function for substituting types + */ +#ifndef TVM_RELAY_TYPECK_TYPE_SUBST_H_ +#define TVM_RELAY_TYPECK_TYPE_SUBST_H_ + +#include "tvm/relay/ir.h" + +namespace tvm { +namespace relay { + +Type type_subst(const Type & type, const TypeParam & target, const Type & subst); +Type type_subst(const Type &type, tvm::Map subst_map); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_TYPECK_TYPE_SUBST_H_ From 4bf87799a47fa5cc08486224046a77da58261490 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Aug 2018 14:35:06 -0700 Subject: [PATCH 38/88] Clean up code while refactoring inference --- include/tvm/relay/compiler/environment.h | 15 +- include/tvm/relay/expr.h | 5 + include/tvm/relay/op.h | 20 ++- python/tvm/relay/env.py | 105 ++++-------- python/tvm/relay/ir_builder.py | 45 +++--- python/tvm/relay/op.py | 11 +- src/relay/compiler/environment.cc | 2 +- src/relay/compiler/type_infer.cc | 195 ++++++++++++----------- src/relay/ir/expr.cc | 6 +- src/relay/op/tensor/elemwise.cc | 5 +- tests/python/relay/test_typechecker.py | 9 +- 11 files changed, 214 insertions(+), 204 deletions(-) diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/compiler/environment.h index 5b33e781b399..d5a3ddd73f77 100644 --- a/include/tvm/relay/compiler/environment.h +++ b/include/tvm/relay/compiler/environment.h @@ -40,6 +40,7 @@ class EnvironmentNode : public RelayNode { private: /*! A map from string names to GlobalIds, ensures global uniqueness. */ tvm::Map global_map_; + tvm::Map type_func_map_; // /*! \brief A map from file names to source fragments. */ // SourceMap source_map_ @@ -57,14 +58,12 @@ class EnvironmentNode : public RelayNode { void VisitAttrs(tvm::AttrVisitor* v) final {} TVM_DLL static Environment make( - std::unordered_map global_funcs); - - /*! Add an operator to the Enviroment. */ - void register_op(const Op& op); - void add(const GlobalVar& var, const Function & func, bool update = false); - void try_add(const GlobalVar& var, const Function & func, bool update=false); - void update(const GlobalVar& var, const Function & func); - void remove(const GlobalVar& var); + tvm::Map global_funcs); + + void Add(const GlobalVar& var, const Function & func, bool update = false); + void TryAdd(const GlobalVar& var, const Function & func, bool update=false); + void Update(const GlobalVar& var, const Function & func); + void Remove(const GlobalVar& var); GlobalVar GetGlobalVar(const std::string& str); diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index a4a683297ea5..ff11a41a6e5f 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -364,6 +364,11 @@ class IfNode : public ExprNode { RELAY_DEFINE_NODE_REF(If, IfNode, Expr); +// template +// T Downcast(U u) { + +// } + } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index cae3d9db6920..be81f54ecd69 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -14,6 +14,7 @@ #include #include "./base.h" +#include "./type.h" #include "./expr.h" #include "../attrs.h" @@ -33,6 +34,8 @@ class OpNode : public relay::ExprNode { public: /*! \brief name of the operator */ std::string name; + + Type op_type; /*! * \brief detailed description of the operator * This can be used to generate docstring automatically for the operator. @@ -67,7 +70,7 @@ class OpNode : public relay::ExprNode { } static constexpr const char* _type_key = "relay.Op"; - TVM_DECLARE_NODE_TYPE_INFO(OpNode, Node); + TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode); private: // friend class @@ -145,6 +148,13 @@ class OpRegistry { inline OpRegistry& add_argument(const std::string &name, const std::string &type, const std::string &description); + /*! + * \brief Attach the type function corresponding to the return type. + * \param ty_func The type function to register for the return type. + * \return reference to self. + */ + inline OpRegistry& add_type_func(const std::string & type_func_name); + /*! * \brief Set the type key of attributes. * \param type_key The type of of the attrs field.x @@ -329,6 +339,14 @@ inline OpRegistry& OpRegistry::add_argument(const std::string &name, return *this; } + inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name) { + auto type_func = TypeFunctionNode::make(type_func_name, 0); + for (auto arg : get()->arguments) { + std::cout << arg << std::endl; + } + return *this; + } + inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) get()->num_inputs = n; return *this; diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 9bd63476f1fb..c63197fa8509 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -1,98 +1,57 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import """A global environment storing everything needed to interpret or compile a Realy program.""" from typing import Union, List -from relay.ir import register_relay_node, NodeBase -from relay.ir import GlobalId, OperatorId, Item, FileId, Span, ShapeExtension -from relay.ir import Operator, Defn -from relay._env import * +from .base import register_relay_node, NodeBase +from . import _make +# from relay.ir import GlobalId, OperatorId, Item, FileId, Span, ShapeExtension +# from relay.ir import Operator, Defn +# from relay._env import * import tvm # Move me to C++ if possible. __tgt_host__ = __tgt__ = "llvm" __relay_tvm_context__ = tvm.cpu() -ADD_ID = "__add__" -SUB_ID = "__sub__" -MUL_ID = "__mul__" -DIV_ID = "__div__" -NEG_ID = "__neg__" -LT_ID = "__lt__" -LE_ID = "__le__" -GT_ID = "__gt__" -GE_ID = "__ge__" -EQ_ID = "__eq__" -NE_ID = "__ne__" - @register_relay_node class Environment(NodeBase): """The global Relay environment containing definitions, primitives, options, and more. """ - def add(self, item: Item) -> None: - return Environment_add(self, item) - - def global_id(self, name: str) -> GlobalId: - return Environment_global_id(self, name) - - def operator_id(self, name: str) -> OperatorId: - return Environment_operator_id(self, name) - - def lookup(self, ident: Union[GlobalId, OperatorId]) -> Item: - if isinstance(ident, OperatorId): - return Environment_lookup_operator(self, ident) - else: - return Environment_lookup_global(self, ident) - - def add_source(self, file_name: str, source: str) -> FileId: - return Environment_add_source(self, file_name, source) - - def report_error(self, message: str, span: Span) -> None: - return Environment_report_error(self, message, span) - - def register_shape_ext(self, ext: ShapeExtension) -> None: - return Environment_register_shape_ext(self, ext) - - def display_errors(self) -> None: - return Environment_display_errors(self) - - def operators(self) -> List[Operator]: - return Environment_get_operators(self) - - def defns(self) -> List[Defn]: - return Environment_get_defns(self) - - def tvm_context(self): - return __relay_tvm_context__ + def __init__(self, funcs) -> None: + self.__init_handle_by_constructor__(_make.Environment, funcs) - def add_id(self) -> OperatorId: - return self.operator_id(ADD_ID) + # def add(self, item: Item) -> None: + # return Environment_add(self, item) - def sub_id(self) -> OperatorId: - return self.operator_id(SUB_ID) + # def global_id(self, name: str) -> GlobalId: + # return Environment_global_id(self, name) - def mul_id(self) -> OperatorId: - return self.operator_id(MUL_ID) + # def operator_id(self, name: str) -> OperatorId: + # return Environment_operator_id(self, name) - def div_id(self) -> OperatorId: - return self.operator_id(DIV_ID) + # def lookup(self, ident: Union[GlobalId, OperatorId]) -> Item: + # if isinstance(ident, OperatorId): + # return Environment_lookup_operator(self, ident) + # else: + # return Environment_lookup_global(self, ident) - def neg_id(self) -> OperatorId: - return self.operator_id(NEG_ID) + # def add_source(self, file_name: str, source: str) -> FileId: + # return Environment_add_source(self, file_name, source) - def lt_id(self) -> OperatorId: - return self.operator_id(LT_ID) + # def report_error(self, message: str, span: Span) -> None: + # return Environment_report_error(self, message, span) - def le_id(self) -> OperatorId: - return self.operator_id(LE_ID) + # def register_shape_ext(self, ext: ShapeExtension) -> None: + # return Environment_register_shape_ext(self, ext) - def gt_id(self) -> OperatorId: - return self.operator_id(GT_ID) + # def display_errors(self) -> None: + # return Environment_display_errors(self) - def ge_id(self) -> OperatorId: - return self.operator_id(GE_ID) + # def operators(self) -> List[Operator]: + # return Environment_get_operators(self) - def eq_id(self) -> OperatorId: - return self.operator_id(EQ_ID) + # def defns(self) -> List[Defn]: + # return Environment_get_defns(self) - def ne_id(self) -> OperatorId: - return self.operator_id(NE_ID) + # def tvm_context(self): + # return __relay_tvm_context__ diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index af83c9948be2..07927aef7d24 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -1,8 +1,8 @@ from typing import Any import numpy as np import tvm -from . import type as ty -from . import expr +from .type import FloatType, IntType, BoolType, UIntType, FuncType +from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function from . import op as _op class ExprBuilder(): @@ -10,7 +10,7 @@ def __init__(self, expr): self.expr = expr def __call__(self, *args): - return ExprBuilder(mk.Call(self.expr, list(args), None, None)) + return ExprBuilder(Call(self.expr, list(args), None, None)) def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: """Convert Python values into the appropriate types @@ -30,12 +30,12 @@ def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: # raise Exception(f"can't convert {type(arg)} to a Relay AST") raise Exception(f"unsupported argument type {type(arg)}") -def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> expr.Expr: +def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> Expr: if isinstance(arg, tuple): raise Exception("..") else: value = convert(arg, ctxt) - return ExprBuilder(mk.Constant(value)) + return ExprBuilder(Constant(value)) class WithScope(object): """Auxiliary scope with""" @@ -61,11 +61,18 @@ def __init__(self, params, ret_type, body, type_params): def param_ids(self): return [p.var for p in self.params] + def to_func(self): + return Function( + self.params, + self.ret_type, + self.body, + self.type_params) + def _mk_let(bindings, ret_value): let_expr = ret_value for var, value in reversed(list(bindings.items())): - let_expr = mk.Let(var, value, let_expr, None) + let_expr = Let(var, value, let_expr, None) return let_expr @@ -79,14 +86,14 @@ def __init__(self): def bind(self, name, type, value): - lv = mk.LocalVar(name) + lv = LocalVar(name) self.scopes[-1][name] = lv self.bindings[-1][lv] = value return lv def let(self, name, value, value_type=None): - if not (isinstance(value, expr.Expr) or isinstance(value, ExprBuilder)): + if not (isinstance(value, Expr) or isinstance(value, ExprBuilder)): value = into_ast(value) if isinstance(value, ExprBuilder): @@ -97,9 +104,9 @@ def let(self, name, value, value_type=None): def function(self, *params): relay_params = [] for name, ty in params: - lv = mk.LocalVar(name) + lv = LocalVar(name) self.scopes[-1][name] = lv - relay_params.append(mk.Param(lv, ty)) + relay_params.append(Param(lv, ty)) # self.params.append(relay_params) @@ -108,7 +115,10 @@ def function(self, *params): def _on_exit(): bindings = self.bindings.pop() scope = self.scopes.pop() - # params = self.params.pop() + ret_value = self.ret_value + body = _mk_let(bindings, ret_value) + self.ret_value = None + pfunc.body = body return WithScope(pfunc, _on_exit) @@ -124,9 +134,6 @@ def ret(self, x): def fn_params(self): pass - def op(self, name): - pass - def get(self): """Get the full program""" bindings = self.bindings.pop() @@ -152,16 +159,16 @@ def int_dtype(): return 'uint1' def int_type(bits=32, lanes=1): - return mk.IntType(bits, lanes) + return IntType(bits, lanes) def uint_type(bits=32, lanes=1): - return mk.UIntType(bits, lanes) + return UIntType(bits, lanes) def float_type(bits=32, lanes=1): - return mk.FloatType(bits, lanes) + return FloatType(bits, lanes) def bool_type(lanes=1): - return mk.BoolType(lanes) + return BoolType(lanes) def func_type(args, ret_type, type_params=[], type_constraints=[]): - return mk.FuncType(args, ret_type, type_params, type_constraints) + return FuncType(args, ret_type, type_params, type_constraints) diff --git a/python/tvm/relay/op.py b/python/tvm/relay/op.py index d54edf47c5ee..d36a433e1e85 100644 --- a/python/tvm/relay/op.py +++ b/python/tvm/relay/op.py @@ -6,6 +6,12 @@ from .._ffi.node import convert_to_node from . import _make from ..make import node as _make_node +from .expr import Expr, Call +from .base import register_relay_node + +@register_relay_node +class Op(Expr): + pass def _create_op(op_name): op = _GetOp(op_name) @@ -19,8 +25,9 @@ def _create_op(op_name): def _op_func(*args, **kwargs): args = convert_to_node(args) # Need work to make sure constructor matches - return _make.Call(op, args, - attrs = _make.node(attrs_type_key, **kwargs)) + # can support kwargs later + attrs = _make_node(attrs_type_key, **kwargs) + return Call(op, args, None, []) _op_func.__name__ = op.name return _op_func diff --git a/src/relay/compiler/environment.cc b/src/relay/compiler/environment.cc index 7ce0785f4f8f..a1c6b31076e3 100644 --- a/src/relay/compiler/environment.cc +++ b/src/relay/compiler/environment.cc @@ -18,7 +18,7 @@ using tvm::IRPrinter; using namespace tvm::runtime; Environment EnvironmentNode::make( - std::unordered_map global_funcs) { + tvm::Map global_funcs) { std::shared_ptr n = std::make_shared(); n->items = std::move(global_funcs); return Environment(n); diff --git a/src/relay/compiler/type_infer.cc b/src/relay/compiler/type_infer.cc index 7304bdabe486..40f14b517951 100644 --- a/src/relay/compiler/type_infer.cc +++ b/src/relay/compiler/type_infer.cc @@ -94,7 +94,7 @@ class TypeInferencer : private ExprFunctor { CheckedExpr Infer(const Expr & expr); - Type instantiate(Type t, tvm::Array &ty_args); + Type instantiate(FuncType fn_ty, tvm::Array &ty_args); void report_error(const std::string & msg, Span sp); [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); @@ -192,10 +192,9 @@ class TypeInferencer : private ExprFunctor { throw Error("TupleNode NYI"); } - CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *op) { - // Param param = GetRef(op); - // return { resolve(param->type); - throw Error("ParamNode NYI"); + CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { + auto rtype = resolve(param->type); + return { ParamNode::make(param->var, rtype), rtype }; } // // We should probably generalize the subst code. @@ -236,25 +235,32 @@ class TypeInferencer : private ExprFunctor { // }; CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { - throw Error("FunctionNode NYI"); - // // enter params into context - // auto fn_type = this->with_frame([&]() { - // std::vector arg_types; - // for (auto arg : f->params) { - // this->Check(arg); - // Type arg_type; - // // if arg type can be simply evaluated, try it - // // should be replaced with symbolic evaluation once it exists, - // // you will not have attr information at this point - // try { - // arg_type = simple_eval_shape(arg->type); - // } catch (const dmlc::Error &e) { - // this->report_error(e.what(), arg->span); - // arg_type = arg->type; - // } - // arg_types.push_back(arg_type); - // this->local_stack.insert(arg->id, arg_type); - // } + // First we add the parameters to the context allowing us to check their + // types. + + // TODO(@jroesch): support polymorphism + + std::vector param_types; + std::vector params; + + return this->with_frame([&]() -> CheckedExpr { + for (auto param : f->params) { + CheckedExpr checked_param = this->Infer(param); + Type arg_type; + param_types.push_back(checked_param.type); + params.push_back(GetRef(checked_param.expr.as())); + this->local_stack.insert(param->var, checked_param.type); + } + + auto checked_body = this->Infer(f->body); + auto inferred_rtype = checked_body.type; + auto annotated_rtype = resolve(f->ret_type); + + auto unified_rtype = this->unify(inferred_rtype, annotated_rtype, f->span); + + return { FunctionNode::make(params, unified_rtype, checked_body.expr, {}), + FuncTypeNode::make(param_types, unified_rtype, {}, {}) }; + }); // // typecheck body and ensure that it matches stated return type // // TODO(sslyu): should the unified return type override the annotated @@ -332,95 +338,96 @@ class TypeInferencer : private ExprFunctor { // } // return fn_type; - } CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode *op) { return this->VisitFunction(GetRef(op), false); } - // Type TypeInferencer::instantiate(Type t, tvm::Array &ty_args) { - // const TypeQuantifierNode *ty_quant; - // while ((ty_quant = t.as())) { - // TypeParam id = ty_quant->id; - // TypeVar fresh = TypeVarNode::make(id->kind); - // this->unifier->insert(fresh); - // ty_args.push_back(fresh); - // t = type_subst(ty_quant->boundType, id, fresh); - // } + Type TypeInferencer::instantiate(FuncType fn_ty, tvm::Array &ty_args) { + // const TypeQuantifierNode *ty_quant; + // while ((ty_quant = t.as())) { + // TypeParam id = ty_quant->id; + // TypeVar fresh = TypeVarNode::make(id->kind); + // this->unifier->insert(fresh); + // ty_args.push_back(fresh); + // t = type_subst(ty_quant->boundType, id, fresh); + // } - // if (!check_kind(t)) { - // this->fatal_error("Kind rules broken when instantiating type - // variables", - // t->span); - // } + // if (!check_kind(t)) { + // this->fatal_error("Kind rules broken when instantiating type + // variables", + // t->span); + // } - // return t; - // } + // return t; + } CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { - throw Error("CallNode"); - // Call c = GetRef(op); - // Type fn_ty = this->Check(c->fn); - - // RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl - // << "fn_ty=" << fn_ty << std::endl; - - // // for each type id, insert a type variable and unify with the argument - // types - // // in order - // // to obtain the concrete instantiation - // tvm::Array ty_args; - // if (const TypeQuantifierNode *ty_quant = fn_ty.as()) - // { - // fn_ty = instantiate(GetRef(ty_quant), ty_args); - // } + Call c = GetRef(op); - // if (!fn_ty.as()) { - // this->fatal_error("only expressions with function types can be called", - // c->fn->span); - // } + auto checked_op = this->Infer(c->op); - // // evaluate all shapes up front (require that types be fully concrete) - // Type evaluated = evaluate_concrete_shape(fn_ty, op->attrs); - // std::vector arg_types; + RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl + << "fn_ty=" << fn_ty << std::endl; - // TypeArrow arrow = GetRef(evaluated.as()); - // // TODO(sslyu): figure out how to handle type ids - // // fn_ty = instantiate(fn_ty, ty_args); - // for (auto arg : c->args) { - // auto ty = this->Check(arg); - // arg_types.push_back(ty); - // } + auto fn_ty_node = checked_op.expr.as(); - // auto type_arity = arrow->arg_types.size(); - // auto number_of_args = arg_types.size(); - // if (type_arity != number_of_args) { - // if (type_arity < number_of_args) { - // this->fatal_error("the function is provided too many arguments", - // c->span); - // } else { - // this->fatal_error("the function is provided too few arguments", - // c->span); - // } - // } + if (!fn_ty_node) { + this->fatal_error("only expressions with function types can be called", c->fn->span); + } - // for (size_t i = 0; i < arrow->arg_types.size(); i++) { - // this->unify(arrow->arg_types[i], arg_types[i], c->args[i]->span); - // } + // We now have a function type. + FuncType fn_ty = GetRef(fn_ty_node); - // // After we unify the arguments we should know more about the type - // // arguments, let's run a quick pass over them to find new - // representatives. for (size_t i = 0; i < ty_args.size(); i++) { - // ty_args.Set(i, this->unifier->subst(ty_args[i])); - // } + tvm::Array ty_args; + if (ty_args.size() != 0) { + throw Error("found manually suplied type args, not supported"); + } + + fn_ty = instantiate(fn_ty, ty_args); + + std::vector arg_types; + + + // TODO(sslyu): figure out how to handle type ids + // fn_ty = instantiate(fn_ty, ty_args); + for (auto arg : c->args) { + auto checked_arg = this->Infer(arg); + arg_types.push_back(checked_arg.type); + } + + auto type_arity = fn_ty->arg_types.size(); + auto number_of_args = arg_types.size(); + + if (type_arity != number_of_args) { + if (type_arity < number_of_args) { + this->fatal_error("the function is provided too many arguments", + c->span); + } else { + this->fatal_error("the function is provided too few arguments", + c->span); + } + } + + for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { + this->unify(fn_ty->arg_types[i], arg_types[i], c->args[i]->span); + } + + // After we unify the arguments we should know more about the type + // arguments, let's run a quick pass over them to find new + // representatives. + + for (size_t i = 0; i < ty_args.size(); i++) { + ty_args.Set(i, this->unifier->subst(ty_args[i])); + } - // // Write the type arguments into the call node, recording what inference - // // solves. This solution might need some work. - // c->ty_args = ty_args; + // Write the type arguments into the call node, recording what inference + // solves. This solution might need some work. + c->ty_args = ty_args; - // return arrow->ret_type; + return { new_call, call_type } } // Type TypeInferencer::VisitExpr_(const DebugNode *op) { diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 38df81940e48..3a3ef1b52604 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -114,7 +114,11 @@ TVM_REGISTER_API("relay._make.Function") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FunctionNode *node, tvm::IRPrinter *p) { - p->stream << "FunctionNode(TODO)"; + p->stream << "FunctionNode(" << + node->params << ", " << + node->ret_type << ", " << + node->body << ", " << + node->type_params << ")"; }); Call CallNode::make(Expr op, Array args, Attrs attrs, diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index 79301a7fac24..50c864650ff4 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -36,6 +36,9 @@ RELAY_REGISTER_UNARY_OP("log") )code" TVM_ADD_FILELINE) .set_support_level(1); +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.add_type_func("Broadcast"); RELAY_REGISTER_UNARY_OP("exp") @@ -57,5 +60,5 @@ RELAY_REGISTER_UNARY_OP("sqrt") )code" TVM_ADD_FILELINE) .set_support_level(1); -} // namespace relay +} // namespace relayv } // namespace tvm diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py index 6a16aadcb002..9c050ecd62d0 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_typechecker.py @@ -2,10 +2,11 @@ for expressions. """ from tvm.relay.type_infer import check_expr -from tvm.relay.ir_builder import IRBuilder, float_type, op +from tvm.relay.ir_builder import IRBuilder, float_type, op, func_type +from tvm.relay.env import Environment def has_type(expr, typ): - env = mk.Environment({}) + env = Environment({}) checked_expr = check_expr(env, expr) return checked_expr.checked_type() == typ @@ -20,11 +21,11 @@ def test_monomorphic_let(): def test_single_op(): - "Program: fn (x : int32) { let t1 = f(x); t1 }" + "Program: fn (x : float32) { let t1 = f(x); t1 }" b = IRBuilder() f = op('log') with b.function(('x', float_type())) as func: x, = func.param_ids() t1 = b.let('t1', f(x)) b.ret(t1) - import pdb; pdb.set_trace() + assert has_type(func.to_func(), func_type([float_type()], float_type())) From 05051d3e99c50dac2557e22378989a501bb644a1 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Aug 2018 14:35:19 -0700 Subject: [PATCH 39/88] Restore old TVM backend code --- python/tvm/relay/tvm_rts_backend.py | 239 ++++++++++++++++++++++++++++ 1 file changed, 239 insertions(+) create mode 100644 python/tvm/relay/tvm_rts_backend.py diff --git a/python/tvm/relay/tvm_rts_backend.py b/python/tvm/relay/tvm_rts_backend.py new file mode 100644 index 000000000000..137230ace63a --- /dev/null +++ b/python/tvm/relay/tvm_rts_backend.py @@ -0,0 +1,239 @@ +"""A compiler from Relay programs to TVM's graph runtime. +""" +import json +from typing import Dict, Any, List, Tuple + +import attr + +from relay.frontend import get_env +from . import ir +from .tyck import get_checked_type +from .opt import AbstractExprVisitor, compile_ops_to_module +from ._make import Operator_is_generic + + +@attr.s(auto_attribs=True) +class NodeRef: + ident: int + index: int = 0 + version: int = 0 + + def to_json(self) -> Any: + return [self.ident, self.index, self.version] + + +@attr.s(auto_attribs=True) +class Node(): + name: str + attrs: Dict[str, Any] + is_output: bool + + def to_json(self) -> Any: + raise Exception("Abstract method, please implement me.") + + +@attr.s(auto_attribs=True) +class InputNode(Node): + """An input node in the graph representation we lower to before NNVM's graph.""" + is_output: bool = False + + def to_json(self): + return { + "op": "null", + "name": self.name, + "inputs": [] + } + + +@attr.s(auto_attribs=True) +class OpNode(Node): + """An operator node in the graph representation we lower to before NNVM's graph.""" + op_name: str + inputs: List[NodeRef] + op_attrs: Dict[str, Any] + is_output: bool = False + + def to_json(self) -> Any: + attrs = dict.copy(self.op_attrs) + # Extend ops with extra info. + attrs['func_name'] = self.op_name + # When do we flatten? + attrs['flatten_data'] = "0" + # Fix me! + attrs['num_inputs'] = str(len(self.inputs)) + attrs['num_outputs'] = "1" + + return { + "op": "tvm_op", + "name": self.name, + "attrs": attrs, + "inputs": self.inputs + } + + +def from_tensor(typ: ir.TensorType) -> Tuple[str, List[int]]: + dtype = typ.dtype.dtype + shape = typ.shape + dims = [] + for dim in shape.shapes: + dims.append(dim.value) + return dtype, dims + + +class TVMRTSCompiler(AbstractExprVisitor[NodeRef]): + """The compiler from Relay to the TVM runtime system.""" + nodes: List[Node] + id_map: Dict[ir.LocalId, NodeRef] + + def __init__(self) -> None: + self.nodes = [] + self.id_map = {} + + def add_node(self, node: Node) -> NodeRef: + self.nodes.append(node) + ident = len(self.nodes) - 1 + return NodeRef(ident) + + def add_binding(self, ident: ir.LocalId, ref: NodeRef) -> None: + self.id_map[ident] = ref + + def let_bind(self, ident: ir.LocalId, node: Node) -> NodeRef: + ref = self.add_node(node) + self.add_binding(ident, ref) + return ref + + def get_node(self, ref: NodeRef) -> Node: + return self.nodes[ref.ident] + + def lookup(self, ident: ir.LocalId) -> NodeRef: + return self.id_map[ident] + + def compile(self, func: ir.Function) -> None: + """Compile a single function into a graph.""" + # TODO: (@jroesch) Restore me + # assert len(fn.ty_params) == 0 + + # First we convert all the parameters into input nodes. + params = func.params + + for param in params: + dtype, shape = from_tensor(param.type) + node = InputNode(f"{param.id.name}", { + "shape": shape, + "dtype": dtype, + }) + self.let_bind(param.id, node) + + # Then we compile the body into a graph which can depend + # on input variables. + output_ref = self.visit(func.body) + + # Finally we retreive return value of program, which will + # become our output node. + self.get_node(output_ref).is_output = True + + def visit_let(self, let: ir.Let) -> NodeRef: + """Visit the Let binding, by first traversing its value, + then setting the metadata on the returned NodeRef. + + Finally visit the body, and return the NodeRef corresponding + to it. + """ + ident = let.id + val = let.value + body = let.body + + # Need to add type info? + val_ref = self.visit(val) + dtype, shape = from_tensor(get_checked_type(val)) + val_node = self.get_node(val_ref) + val_node.attrs["dtype"] = dtype + val_node.attrs["shape"] = shape + self.add_binding(ident, val_ref) + return self.visit(body) + + def visit_local_id(self, ident: ir.LocalId) -> NodeRef: + return self.lookup(ident) + + def visit_call(self, call: ir.Call) -> NodeRef: + inputs = [] + for arg in call.args: + inputs.append(self.visit(arg).to_json()) + + # need to deal with name mangle + op_name = call.fn.name + op_node = OpNode("call_name", {}, op_name, inputs, {}) + return self.add_node(op_node) + + def to_json(self) -> str: + """Convert the sequence of nodes stored by the compiler into the + JSON format defined in: https://docs.tvm.ai/dev/nnvm_json_spec.html. + """ + nodes = [] + # First we compute "nodes" field. + for node in self.nodes: + nodes.append(node.to_json()) + + arg_nodes = [] + heads = [] + # Compute "arg_nodes" and "heads" fields. + for i, node in enumerate(self.nodes): + if isinstance(node, InputNode): + arg_nodes.append(i) + + if node.is_output: + # Need to fix this. + heads.append(NodeRef(i).to_json()) + + # Compute "node_row_ptr". + # TODO + + # Compute "attrs" field. + attrs = {} + + # A + shapes = [] + storage_ids = [] + dtype = [] + dltype = [] + + for i, node in enumerate(self.nodes): + storage_ids.append(i) + shapes.append(node.attrs['shape']) + if node.attrs['dtype'] == 'float32': + dtype.append(0) + dltype.append('float32') + + attrs["shape"] = ["list_shape", shapes] + attrs["storage_id"] = ["list_int", storage_ids] + attrs["dtype"] = ["list_int", dtype] + attrs["dltype"] = ["list_str", dltype] + + json_dict = { + "nodes": nodes, + "arg_nodes": arg_nodes, + "heads": heads, + "attrs": attrs + } + + return json.dumps(json_dict) + + +def compile_to_tvm(func): + """Compile a single function to the components needed by the + TVM RTS. + """ + env = get_env() + iids = [] + + # Why do I need to call items? + for op in env.operators(): + if not Operator_is_generic(op): + iids.append(op.id) + + # TODO(@jroesch): Need to write test case for this + mod = compile_ops_to_module(env, iids) + comp = TVMRTSCompiler() + comp.compile(func) + graph_json = comp.to_json() + return graph_json, mod, None # params currently isn't supported by API From d9d364205ced04cc0ada4e476c05447911f83e49 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Aug 2018 15:33:44 -0700 Subject: [PATCH 40/88] Type checker is working for one op case --- python/tvm/relay/expr.py | 1 + python/tvm/relay/ir_builder.py | 3 - python/tvm/relay/op.py | 44 ----- python/tvm/relay/op/__init__.py | 6 + src/relay/compiler/resolve.cc | 2 + src/relay/compiler/type_infer.cc | 256 ++++++++++--------------- src/relay/op/tensor/elemwise.cc | 4 +- tests/python/relay/test_typechecker.py | 10 +- 8 files changed, 113 insertions(+), 213 deletions(-) delete mode 100644 python/tvm/relay/op.py diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 7f5dcbd0beb5..e98d74f3da88 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -71,6 +71,7 @@ class Function(Expr): def __init__(self, params: List[Param], ret_type: Type, body: Expr, type_params: List[TypeParam]=[]) -> None: self.__init_handle_by_constructor__(_make.Function, params, ret_type, body, type_params) +@register_relay_node class Call(Expr): op: Expr args: List[Expr] diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 07927aef7d24..8bd225bd4de1 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -149,9 +149,6 @@ def get(self): return _mk_let(bindings, self.ret_value) -def op(name): - return _op._create_op(name) - def bool_dtype(): return 'uint1' diff --git a/python/tvm/relay/op.py b/python/tvm/relay/op.py deleted file mode 100644 index d36a433e1e85..000000000000 --- a/python/tvm/relay/op.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Relay operators""" -from __future__ import absolute_import as _abs - -import sys - -from .._ffi.node import convert_to_node -from . import _make -from ..make import node as _make_node -from .expr import Expr, Call -from .base import register_relay_node - -@register_relay_node -class Op(Expr): - pass - -def _create_op(op_name): - op = _GetOp(op_name) - attrs_type_key = op.attrs_type_key - attrs_type_key = attrs_type_key if attrs_type_key else "DictAttrs" - # TODO(tqchen): improve the code build to fix the restriction. - # - # current restriction: - # - pass in args as positional arguments - # - pass in kwargs as keyword argument - def _op_func(*args, **kwargs): - args = convert_to_node(args) - # Need work to make sure constructor matches - # can support kwargs later - attrs = _make_node(attrs_type_key, **kwargs) - return Call(op, args, None, []) - _op_func.__name__ = op.name - return _op_func - - -def _init_ops(): - """Helper function to initialize the operators - """ - module = sys.modules[__name__] - for name in _ListOpNames(): - f = _create_op(name.value) - setattr(module, f.__name__, f) - - -_init_ops() diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 02e49ec40ff8..ad2f54929aed 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -4,3 +4,9 @@ # operator registry from . import _tensor +from ..expr import Expr +from ..base import register_relay_node + +@register_relay_node +class Op(Expr): + pass diff --git a/src/relay/compiler/resolve.cc b/src/relay/compiler/resolve.cc index 2d3e84dc2160..236722b23387 100644 --- a/src/relay/compiler/resolve.cc +++ b/src/relay/compiler/resolve.cc @@ -53,6 +53,7 @@ struct ResolveTypeExpr : ExprFVisitor<> { // term, then resolve e's old type and write // it back into the new node. auto new_e = ExprFVisitor::VisitExpr(e); + CHECK(e->checked_type_.defined()); auto resolved_cty = VisitType(e->checked_type_); new_e->checked_type_ = resolved_cty; return new_e; @@ -64,6 +65,7 @@ struct ResolveTypeExpr : ExprFVisitor<> { }; Type resolve(const TypeUnifier &unifier, const Type &ty) { + CHECK(ty.defined()); return ResolveTypeType(unifier).VisitType(ty); } diff --git a/src/relay/compiler/type_infer.cc b/src/relay/compiler/type_infer.cc index 40f14b517951..e2e5999e7341 100644 --- a/src/relay/compiler/type_infer.cc +++ b/src/relay/compiler/type_infer.cc @@ -27,6 +27,7 @@ #include "./incomplete_type.h" #include "./unifier.h" #include "./resolve.h" +#include "./type_subst.h" // #include "tvm/relay/alpha_eq.h" // #include "tvm/relay/debug.h" // #include "tvm/relay/first_order_reverse_ad.h" @@ -71,6 +72,7 @@ struct CheckedExpr { Expr expr; Type type; CheckedExpr(Expr e, Type t) : expr(e), type(t) {} + CheckedExpr() {} }; class TypeInferencer : private ExprFunctor { @@ -94,7 +96,7 @@ class TypeInferencer : private ExprFunctor { CheckedExpr Infer(const Expr & expr); - Type instantiate(FuncType fn_ty, tvm::Array &ty_args); + FuncType instantiate(FuncType fn_ty, tvm::Array &ty_args); void report_error(const std::string & msg, Span sp); [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); @@ -103,7 +105,7 @@ class TypeInferencer : private ExprFunctor { Type resolve(const Type &t); Expr resolve(const Expr &e); CheckedExpr VisitFunction(const Function & f, bool generalize); - // Operator CheckOp(Operator op); + void CheckOp(Op op); // Defn CheckDefn(Defn def); private: CheckedExpr VisitExpr_(const LocalVarNode* op) override; @@ -115,6 +117,7 @@ class TypeInferencer : private ExprFunctor { CheckedExpr VisitExpr_(const CallNode* op) override; CheckedExpr VisitExpr_(const LetNode* op) override; CheckedExpr VisitExpr_(const IfNode* op) override; + CheckedExpr VisitExpr_(const OpNode* op) override; }; TypeInferencer::TypeInferencer() { @@ -145,7 +148,7 @@ class TypeInferencer : private ExprFunctor { // GlobalVar id = GetRef(op); // Item item = this->env->lookup(id); - // if (const OperatorNode *op = item.as()) { + // if (const OpNode *op = item.as()) { // return op->type; // } @@ -167,12 +170,12 @@ class TypeInferencer : private ExprFunctor { TensorTypeNode::make({}, HalideIR::Float(32, 1)) }; } - // Type TypeInferencer::VisitExpr_(const OperatorIdNode *op) { - // OperatorId id = GetRef(op); + // Type TypeInferencer::VisitExpr_(const OpIdNode *op) { + // OpId id = GetRef(op); // Item item = this->env->lookup(id); - // if (const OperatorNode *pn = item.as()) { - // Operator prim = GetRef(pn); + // if (const OpNode *pn = item.as()) { + // Op prim = GetRef(pn); // return prim->type; // } else { // this->fatal_error("internal error in InstrinsicId case", op->span); @@ -344,15 +347,20 @@ class TypeInferencer : private ExprFunctor { return this->VisitFunction(GetRef(op), false); } - Type TypeInferencer::instantiate(FuncType fn_ty, tvm::Array &ty_args) { - // const TypeQuantifierNode *ty_quant; - // while ((ty_quant = t.as())) { - // TypeParam id = ty_quant->id; - // TypeVar fresh = TypeVarNode::make(id->kind); - // this->unifier->insert(fresh); - // ty_args.push_back(fresh); - // t = type_subst(ty_quant->boundType, id, fresh); - // } + FuncType TypeInferencer::instantiate(FuncType fn_ty, tvm::Array &ty_args) { + tvm::Map subst_map; + + // Build a subsitituion map up from the function type and type arguments. + // Eventually allow the type vars to be passed in. + for (auto ty_param : fn_ty->type_params) { + IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); + this->unifier->insert(fresh); + ty_args.push_back(fresh); + subst_map.Set(ty_param, fresh); + } + + Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); + inst_ty = type_subst(fn_ty, subst_map); // if (!check_kind(t)) { // this->fatal_error("Kind rules broken when instantiating type @@ -360,7 +368,7 @@ class TypeInferencer : private ExprFunctor { // t->span); // } - // return t; + return GetRef(inst_ty.as()); } CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { @@ -369,13 +377,13 @@ class TypeInferencer : private ExprFunctor { auto checked_op = this->Infer(c->op); RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl - << "fn_ty=" << fn_ty << std::endl; + << "fn_ty=" << checked_op.type << std::endl; - auto fn_ty_node = checked_op.expr.as(); + auto fn_ty_node = checked_op.type.as(); if (!fn_ty_node) { - this->fatal_error("only expressions with function types can be called", c->fn->span); + this->fatal_error("only expressions with function types can be called", c->op->span); } // We now have a function type. @@ -389,13 +397,12 @@ class TypeInferencer : private ExprFunctor { fn_ty = instantiate(fn_ty, ty_args); std::vector arg_types; + std::vector checked_args; - - // TODO(sslyu): figure out how to handle type ids - // fn_ty = instantiate(fn_ty, ty_args); for (auto arg : c->args) { auto checked_arg = this->Infer(arg); arg_types.push_back(checked_arg.type); + checked_args.push_back(checked_arg.expr); } auto type_arity = fn_ty->arg_types.size(); @@ -423,164 +430,100 @@ class TypeInferencer : private ExprFunctor { ty_args.Set(i, this->unifier->subst(ty_args[i])); } - // Write the type arguments into the call node, recording what inference - // solves. This solution might need some work. - c->ty_args = ty_args; + auto new_call = CallNode::make(checked_op.expr, checked_args, c->attrs, ty_args); - return { new_call, call_type } + return { new_call, fn_ty->ret_type }; } - // Type TypeInferencer::VisitExpr_(const DebugNode *op) { - // return this->Check(op->node); - // } - CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) { Let let = GetRef(op); - Type checked_ty; + CheckedExpr checked_value; Type annotated_ty = resolve(let->value_type); - // if we are let-defining a function, treat it as a let-rec and insert - // the id with the annotated type in case there is recursion; - // no such recursion permitted with anything that's not a function! - // if (let->value.as()) { - // with_frame([&]() { - // local_stack.insert(let->id, annotated_ty); - // checked_ty = Check(let->value); - // }); - // } else { - checked_ty = Infer(let->value).type; - // } - // ensure annotated type and checked type are compatible - // TODO(sslyu): should the annotated type override the unified one? + // If we are let-defining a function, we want to be able to + // recursively name the function in order to support recursive + // local definitions. + if (let->value.as()) { + with_frame([&]() { + local_stack.insert(let->var, annotated_ty); + checked_value = Infer(let->value); + }); + } else { + checked_value = Infer(let->value); + } + Type unified_ty = - this->unify(checked_ty, annotated_ty, let->span); + this->unify(checked_value.type, annotated_ty, let->span); + + // Update type context with unified type now that we have + // solved this equation. + local_stack.insert(let->var, unified_ty); - return with_frame([&]() { + auto checked_body = with_frame([&]() { local_stack.insert(let->var, unified_ty); return Infer(let->body); }); - } - // Type TypeInferencer::VisitExpr_(const ReverseNode *op) { - // // apply reverse mode to node and typecheck that instead - // std::shared_ptr gf = std::make_shared(); - // return this->Check(ReverseExpr(env, op->node, gf)); - // } - - // Type TypeInferencer::VisitExpr_(const GradientNode *op) { - // auto node = op->node; - // this->Check(node); - // auto gf = std::make_shared(); - // return FOWithGradientType(node->checked_type()); - // } - - // Type TypeInferencer::VisitExpr_(const ProjectionNode *op) { - // Projection proj = GetRef(op); - - // Type tup_type = this->Check(proj->tuple); - - // const TupleTypeNode *ptn = tup_type.as(); - // if (!ptn) { - // this->fatal_error("Cannot project into non-product type", op->span); - // } + auto checked_let = LetNode::make( + let->var, + checked_value.expr, + checked_body.expr, + let->value_type); - // TupleType pt = GetRef(ptn); - // size_t field = (size_t)proj->field; - // if (field >= pt->fields.size()) { - // this->fatal_error("Projecting past bounds of product", op->span); - // } - - // return pt->fields[field]; - // } + return { checked_let, checked_body.type }; + } CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { - // If ifn = GetRef(op); - - // // Ensure the type of the guard is of Tensor[Bool, ()], - // // that is a rank-0 boolean tensor. - // Type guardType = this->Check(ifn->guard); - // bool is_bool = false; - // bool zero_rank = false; - // if (const TensorTypeNode *ttn = guardType.as()) { - // TensorType tt = GetRef(ttn); - - // if (const BaseTypeNode *btn = tt->dtype.as()) { - // is_bool = btn->type.is_bool(); - // } - - // Type shape = simple_eval_shape(tt->shape); - - // if (const ShapeSeqNode *sn = shape.as()) { - // zero_rank = (sn->shapes.size() == 0); - // } - // } - - // if (!(is_bool && zero_rank)) { - // this->fatal_error("IfNode guard must be a rank 0 bool tensor", - // ifn->guard->span); - // } + If ifn = GetRef(op); + + // Ensure the type of the guard is of Tensor[Bool, ()], + // that is a rank-0 boolean tensor. + auto checked_cond = this->Infer(ifn->cond); + auto cond_type = checked_cond.type; + + if (const TensorTypeNode *tt_node = cond_type.as()) { + TensorType tt = GetRef(tt_node); + if (tt->dtype.is_bool() && tt->shape.size() == 0) { + auto checked_true = this->Infer(ifn->true_value); + auto checked_false = this->Infer(ifn->false_value); + auto unified_type = this->unify(checked_true.type, checked_false.type, ifn->span); + auto checked_if = IfNode::make(checked_cond.expr, checked_true.expr, checked_false.expr); + return { checked_if, unified_type }; + } + } - // // unify types of different branches - // Type left = this->Check(ifn->true_b); - // Type right = this->Check(ifn->false_b); - // return this->unify(left, right, ifn->span); + this->fatal_error("if-then-else guard must be a rank-0 boolean tensor", + ifn->cond->span); } - // Type TypeInferencer::VisitExpr_(const RefNode *op) { - // Ref r = GetRef(op); - // Type inner = this->Check(r->expr); - // return RefTypeNode::make(inner); - // } - - // Type TypeInferencer::VisitExpr_(const ReadRefNode *op) { - // ReadRef vr = GetRef(op); - // Type ref_type = this->Check(vr->ref); - - // // reject if not a ref type - // const RefTypeNode *rtn = ref_type.as(); - // if (!rtn) { - // this->fatal_error( - // "the de-reference operation can only be used with references", - // op->span); - // } - - // RefType rt = GetRef(rtn); - // return rt->data_type; - // } - - // Type TypeInferencer::VisitExpr_(const WriteRefNode *op) { - // WriteRef sr = GetRef(op); - // Type ref_type = this->Check(sr->ref); - - // const RefTypeNode *rtn = ref_type.as(); - // if (!rtn) { - // this->fatal_error("Cannot mutate non-ref", op->span); - // } - // RefType rt = GetRef(rtn); - - // // ensure ref type's inner type and expr's type are compatible; return - // unit Type expr_type = this->Check(sr->val); this->unify(rt->data_type, - // expr_type, sr->span); return UnitType(); - // } + CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op) { + return { GetRef(op), FuncTypeNode::make({}, TensorTypeNode::Int(32), {}, {} )}; + } Type TypeInferencer::resolve(const Type &t) { - return ::tvm::relay::resolve(this->unifier, t); + if (t.defined()) { + return ::tvm::relay::resolve(this->unifier, t); + } else { + return IncompleteTypeNode::make(TypeParamNode::Kind::kType); + } } Expr TypeInferencer::resolve(const Expr &e) { + CHECK(e.defined()); return ::tvm::relay::resolve(this->unifier, e); } - // Operator TypeInferencer::CheckOp(Operator op) { - // if (!check_kind(op->type)) { - // report_error("the type of the operator is ill formed", op->type->span); - // } + void TypeInferencer::CheckOp(Op op) { + throw Error("NYI"); + // if (!check_kind(op->type)) { + // report_error("the type of the operator is ill formed", op->type->span); + // } - // // Fix me - // return op; - // } + // // Fix me + // return op; + } // Defn TypeInferencer::CheckDefn(Defn defn) { // // This is to handle recursion, but we need to speculatively @@ -620,8 +563,8 @@ class TypeInferencer : private ExprFunctor { // try { // if (const DefnNode *defn = i.as()) { // return tc.CheckDefn(GetRef(defn)); - // } else if (const OperatorNode *op_node = i.as()) { - // return tc.CheckOp(GetRef(op_node)); + // } else if (const OpNode *op_node = i.as()) { + // return tc.CheckOp(GetRef(op_node)); // } else { // throw dmlc::Error("internal error: unknown Item type"); // } @@ -717,13 +660,6 @@ class TypeInferencer : private ExprFunctor { *ret = Infer(env, e); }); - // TVM_REGISTER_API("relay._tyck.check_item") - // .set_body([](TVMArgs args, TVMRetValue *ret) { - // Environment env = args[0]; - // Item i = args[1]; - // *ret = check(env, i); - // }); - TVM_REGISTER_API("relay._type_infer._get_checked_type") .set_body([](TVMArgs args, TVMRetValue *ret) { Expr e = args[0]; diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index 50c864650ff4..d1d3e01ed9a6 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -35,9 +35,7 @@ RELAY_REGISTER_UNARY_OP("log") log(x) )code" TVM_ADD_FILELINE) -.set_support_level(1); -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(1) .add_type_func("Broadcast"); diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py index 9c050ecd62d0..d111bba9dfbf 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_typechecker.py @@ -2,8 +2,9 @@ for expressions. """ from tvm.relay.type_infer import check_expr -from tvm.relay.ir_builder import IRBuilder, float_type, op, func_type +from tvm.relay.ir_builder import IRBuilder, float_type, func_type from tvm.relay.env import Environment +from tvm.relay.op import log def has_type(expr, typ): env = Environment({}) @@ -23,9 +24,12 @@ def test_monomorphic_let(): def test_single_op(): "Program: fn (x : float32) { let t1 = f(x); t1 }" b = IRBuilder() - f = op('log') with b.function(('x', float_type())) as func: x, = func.param_ids() - t1 = b.let('t1', f(x)) + t1 = b.let('t1', log(x)) b.ret(t1) assert has_type(func.to_func(), func_type([float_type()], float_type())) + +if __name__ == "__main__": + test_monomorphic_let() + test_single_op() From cd5c1059e64b4a48a2a97477f98b7cb61cdae83b Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 24 Aug 2018 16:03:53 -0700 Subject: [PATCH 41/88] op step 2 --- include/tvm/relay/op.h | 18 +++++-- python/tvm/relay/__init__.py | 8 +-- python/tvm/relay/expr.py | 2 + python/tvm/relay/op/__init__.py | 3 ++ python/tvm/relay/op/op.py | 77 +++++++++++++++++++++++++++++ python/tvm/relay/op/registry.py | 1 - src/relay/ir/op.cc | 62 ++++++++++++++++++----- tests/python/relay/test_relay_op.py | 13 +++++ 8 files changed, 164 insertions(+), 20 deletions(-) create mode 100644 python/tvm/relay/op/op.py delete mode 100644 python/tvm/relay/op/registry.py diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index be81f54ecd69..15c55dee52c0 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -103,7 +103,7 @@ class Op : public relay::Expr { * \tparam ValueType The type of the attribute. */ template - inline static const OpMap& GetAttr(const std::string& attr_name); + inline static OpMap GetAttr(const std::string& attr_name); /*! * \brief Get an Op for a given operator name. * Will raise an error if the op has not been registered. @@ -193,9 +193,13 @@ class OpRegistry { // set the name of the op to be the same as registry inline OpRegistry& set_name() { // NOLINT(*) - get()->name = name; + if (get()->name.length() == 0) { + get()->name = name; + } return *this; } + /*! \return The global single retistry */ + TVM_DLL static ::dmlc::Registry* Registry(); private: friend class ::dmlc::Registry; @@ -307,7 +311,7 @@ class OpMap { */ #define RELAY_REGISTER_OP(OpName) \ DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \ - ::dmlc::Registry<::tvm::relay::OpRegistry>::Get()->__REGISTER_OR_GET__(OpName).set_name() + ::tvm::relay::OpRegistry::Registry()->__REGISTER_OR_GET__(OpName).set_name() // implementations inline const OpNode* Op::operator->() const { @@ -315,7 +319,7 @@ inline const OpNode* Op::operator->() const { } template -inline const OpMap& Op::GetAttr(const std::string& key) { +inline OpMap Op::GetAttr(const std::string& key) { return OpMap(Op::GetGenericAttr(key)); } @@ -352,6 +356,12 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) return *this; } +inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*) + const std::string& type_key) { + get()->attrs_type_key = type_key; + return *this; +} + inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) get()->support_level = n; return *this; diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index f94b572f6b44..019d7c19a865 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -4,10 +4,6 @@ from . import expr from . import op -# import all operators in the loop namespace -from .op import * - - # Span Span = base.Span @@ -31,3 +27,7 @@ Let = expr.Let If = expr.If Var = LocalVar + +# Operators +from .op import Op +from .op.tensor import * diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index e98d74f3da88..41066829e2f3 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -99,3 +99,5 @@ class If(Expr): def __init__(self, cond: Expr, true_value: Expr, false_value: Expr) -> None: self.__init_handle_by_constructor__(_make.If, cond, true_value, false_value) + + diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index ad2f54929aed..3d87d78fe633 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -1,5 +1,8 @@ """Relay core operators.""" # operator defs +from .op import get, register, Op + +# Operators from .tensor import * # operator registry diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py new file mode 100644 index 000000000000..4540b19f5ccf --- /dev/null +++ b/python/tvm/relay/op/op.py @@ -0,0 +1,77 @@ +"""The base node types for the Relay language.""" +from ..._ffi.function import _init_api + +from ..base import register_relay_node +from ..expr import Expr +from ..._ffi.function import Function +from ...api import convert + +@register_relay_node +class Op(Expr): + def __init__(self): + raise RuntimeError("Cannot create op, use get instead") + + def get_attr(self, attr_name): + """Get additional attribute about the operator. + + Parameters + ---------- + attr_name : str + The attribute name. + + Returns + ------- + value : object + The attribute value + """ + return _OpGetAttr(self, attr_name) + + +def get(op_name): + """Get the Op for a given name + + Parameters + ---------- + op_name : str + The operator name + + Returns + ------- + op : Op + The op of the corresponding name + """ + return _GetOp(op_name) + + +def register(op_name, attr_key, value=None, level=10): + """Register an operator property of an operator. + + + Parameters + ---------- + op_name : str + The name of operator + + attr_key : str + The attribute name. + + value : object, optional + The value to set + + level : int, optional + The priority level + + Returns + ------- + fregister : function + Register function if value is not specified. + """ + def _register(v): + """internal register function""" + _Register(op_name, attr_key, v, level) + return v + return _register(value) if value else _register + + +_init_api("relay.op", __name__) + diff --git a/python/tvm/relay/op/registry.py b/python/tvm/relay/op/registry.py deleted file mode 100644 index d7426429ef6f..000000000000 --- a/python/tvm/relay/op/registry.py +++ /dev/null @@ -1 +0,0 @@ -"""Mechanism to work with operator registry.""" diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 5a4241a182b1..664947425b53 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -10,6 +10,10 @@ DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry); namespace tvm { namespace relay { +::dmlc::Registry* OpRegistry::Registry() { + return ::dmlc::Registry::Get(); +} + // single manager of operator information. struct OpManager { // mutex to avoid registration from multiple threads. @@ -18,6 +22,8 @@ struct OpManager { std::atomic op_counter{0}; // storage of additional attribute table. std::unordered_map > attr; + // frontend functions + std::vector frontend_funcs; // get singleton of the static OpManager* Global() { static OpManager inst; @@ -75,22 +81,56 @@ void OpRegistry::UpdateAttr( } // Frontend APIs -using runtime::TypedPackedFunc; - TVM_REGISTER_API("relay.op._ListOpNames") -.set_body(TypedPackedFunc()>([]() { - Array ret; - for (const std::string& name : - dmlc::Registry::ListAllNames()) { - ret.push_back(tvm::Expr(name)); - } - return ret; - })); +.set_body_typed()>([]() { + Array ret; + for (const std::string& name : + dmlc::Registry::ListAllNames()) { + ret.push_back(tvm::Expr(name)); + } + return ret; + }); TVM_REGISTER_API("relay.op._GetOp") -.set_body(TypedPackedFunc(Op::Get)); +.set_body_typed(Op::Get); +TVM_REGISTER_API("relay.op._OpGetAttr") +.set_body([](TVMArgs args, TVMRetValue* rv) { + Op op = args[0]; + std::string attr_name = args[1]; + auto op_map = Op::GetAttr(attr_name); + if (op_map.count(op)) { + *rv = op_map[op]; + } + }); + + +TVM_REGISTER_API("relay.op._Register") +.set_body([](TVMArgs args, TVMRetValue* rv) { + std::string op_name = args[0]; + std::string attr_key = args[1]; + runtime::TVMArgValue value = args[2]; + int plevel = args[3]; + auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); + // enable resgiteration and override of certain properties + if (attr_key == "num_inputs" && plevel > 128) { + reg.set_num_inputs(value); + } else if (attr_key == "attrs_type_key" && plevel > 128) { + reg.set_attrs_type_key(value); + } else { + // normal attr table override. + if (args[2].type_code() == kFuncHandle) { + // do an eager copy of the PackedFunc + PackedFunc f = args[2]; + // If we get a function from frontend, avoid deleting it. + OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f)); + reg.set_attr(attr_key, f, plevel); + } else { + reg.set_attr(attr_key, args[2], plevel); + } + } + }); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_relay_op.py b/tests/python/relay/test_relay_op.py index 4235dd918d93..1f95a3f72c15 100644 --- a/tests/python/relay/test_relay_op.py +++ b/tests/python/relay/test_relay_op.py @@ -1,5 +1,16 @@ from tvm import relay +def test_op_attr(): + log_op = relay.op.get("log") + + @relay.op.register("exp", "ftest") + def test(x): + return x + 1 + + assert log_op.num_inputs == 1 + assert log_op.get_attr("ftest") is None + assert relay.op.get("exp").get_attr("ftest")(1) == 2 + def test_op_level1(): x = relay.Var("x") @@ -11,4 +22,6 @@ def test_op_level1(): if __name__ == "__main__": + test_op_attr() test_op_level1() + From 22aa1ca227afb92957f044697cc0a8d89577590c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Aug 2018 16:11:07 -0700 Subject: [PATCH 42/88] WIP --- include/tvm/relay/compiler/environment.h | 2 +- include/tvm/relay/op.h | 6 +++--- include/tvm/relay/type.h | 22 ++++++++++++---------- src/codegen/spirv/ir_builder.cc | 2 +- src/relay/compiler/alpha_eq.cc | 2 +- src/relay/compiler/type_functor.h | 4 ++-- src/relay/compiler/type_visitor.h | 2 +- src/relay/compiler/unifier.cc | 2 +- src/relay/compiler/unifier.h | 2 +- src/relay/ir/type.cc | 14 +++++++------- src/relay/op/tensor/elemwise.cc | 4 +++- 11 files changed, 33 insertions(+), 29 deletions(-) diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/compiler/environment.h index d5a3ddd73f77..2ec8ca8af933 100644 --- a/include/tvm/relay/compiler/environment.h +++ b/include/tvm/relay/compiler/environment.h @@ -40,7 +40,7 @@ class EnvironmentNode : public RelayNode { private: /*! A map from string names to GlobalIds, ensures global uniqueness. */ tvm::Map global_map_; - tvm::Map type_func_map_; + tvm::Map type_func_map_; // /*! \brief A map from file names to source fragments. */ // SourceMap source_map_ diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 15c55dee52c0..630a231ebb54 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -153,7 +153,7 @@ class OpRegistry { * \param ty_func The type function to register for the return type. * \return reference to self. */ - inline OpRegistry& add_type_func(const std::string & type_func_name); + inline OpRegistry& add_type_func(const std::string & type_func_name, TypeRelationFn type_fn); /*! * \brief Set the type key of attributes. @@ -343,8 +343,8 @@ inline OpRegistry& OpRegistry::add_argument(const std::string &name, return *this; } - inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name) { - auto type_func = TypeFunctionNode::make(type_func_name, 0); + inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name, TypeRelationFn type_fn) { + auto type_func = TypeRelationNode::make(type_func_name, 0); for (auto arg : get()->arguments) { std::cout << arg << std::endl; } diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index ef8c4c71f5b7..68ed411a23ed 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -210,40 +210,42 @@ class FuncTypeNode : public TypeNode { RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); +using TypeRelationFn = std::function(const Array&, int)>; + /*! - * \brief Opaque type inference function. + * \brief Opaque type relation, is an input-output relation on types. */ -class TypeFunction; +class TypeRelation; /*! - * \brief TypeFunction container. + * \brief TypeRelation container. * \note This node is not directly serializable. * The type function need to be lookedup in the environment. */ -class TypeFunctionNode : public RelayNode { +class TypeRelationNode : public RelayNode { public: /*! \brief The name of the function */ std::string name; /*! \brief Number of input type arguments, can be -1, which means VarArgs */ int num_args; /*! - * \brief The type function, + * \brief The function on input and output variables which * this is not directly serializable, * need to be looked-up in the environment. */ - mutable std::function& arg_types)> func_; + TypeRelationFn func_; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("name", &name); v->Visit("num_args", &num_args); } - TVM_DLL static TypeFunction make(std::string name, int num_args); + TVM_DLL static TypeRelation make(std::string name, int num_args); - static constexpr const char* _type_key = "relay.TypeFunction"; - TVM_DECLARE_NODE_TYPE_INFO(TypeFunctionNode, RelayNode); + static constexpr const char* _type_key = "relay.TypeRelation"; + TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, RelayNode); }; -RELAY_DEFINE_NODE_REF(TypeFunction, TypeFunctionNode, Type); +RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, Type); /*! * \brief Call a type function with some number of arguments. diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index 41cb48c5854b..87987dbf08e9 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -41,7 +41,7 @@ void IRBuilder::InitPreDefs() { t_void_.id = id_counter_++; ib_.Begin(spv::OpTypeVoid).Add(t_void_).Commit(&global_); t_void_func_.id = id_counter_++; - ib_.Begin(spv::OpTypeFunction) + ib_.Begin(spv::OpTypeRelation) .AddSeq(t_void_func_, t_void_).Commit(&global_); } diff --git a/src/relay/compiler/alpha_eq.cc b/src/relay/compiler/alpha_eq.cc index 688a93ae73fc..d4f1d888fb69 100644 --- a/src/relay/compiler/alpha_eq.cc +++ b/src/relay/compiler/alpha_eq.cc @@ -92,7 +92,7 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const TypeFunctionNode *op, const Type &t2) override { + void VisitType_(const TypeRelationNode *op, const Type &t2) override { } // void VisitType_(const TupleTypeNode *op, const Type &t2) override { // if (const TupleTypeNode *pt = t2.as()) { diff --git a/src/relay/compiler/type_functor.h b/src/relay/compiler/type_functor.h index 3840c902bfe8..5de56837ca10 100644 --- a/src/relay/compiler/type_functor.h +++ b/src/relay/compiler/type_functor.h @@ -63,7 +63,7 @@ class TypeFunctor { virtual R VisitType_(const TypeParamNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TypeFunctionNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; @@ -81,7 +81,7 @@ class TypeFunctor { RELAY_TYPE_FUNCTOR_DISPATCH(TypeParamNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode); RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); - RELAY_TYPE_FUNCTOR_DISPATCH(TypeFunctionNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode); RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); return vtable; diff --git a/src/relay/compiler/type_visitor.h b/src/relay/compiler/type_visitor.h index 60ae810a6b96..c98ff3ab8958 100644 --- a/src/relay/compiler/type_visitor.h +++ b/src/relay/compiler/type_visitor.h @@ -47,7 +47,7 @@ struct TypeVisitor : ::tvm::relay::TypeFunctor { } } - void VisitType_(const TypeFunctionNode* op, Args... args) override {} + void VisitType_(const TypeRelationNode* op, Args... args) override {} void VisitType_(const IncompleteTypeNode* op, Args... args) override {} }; diff --git a/src/relay/compiler/unifier.cc b/src/relay/compiler/unifier.cc index b7cc296cc5db..2f728a104530 100644 --- a/src/relay/compiler/unifier.cc +++ b/src/relay/compiler/unifier.cc @@ -325,7 +325,7 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { // throw UnificationError("Cannot unify TupleTypeNode"); // } -Type TypeUnifierNode::VisitType_(const TypeFunctionNode *sen1, const Type t2) { +Type TypeUnifierNode::VisitType_(const TypeRelationNode *sen1, const Type t2) { // ShapeExtension sh_ext1 = GetRef(sen1); // if (const IncompleteTypeNode *tvn2 = t2.as()) { diff --git a/src/relay/compiler/unifier.h b/src/relay/compiler/unifier.h index 86ffd664a161..40583b16a55a 100644 --- a/src/relay/compiler/unifier.h +++ b/src/relay/compiler/unifier.h @@ -110,7 +110,7 @@ class TypeUnifierNode : public Node, Type VisitType_(const TypeParamNode* t1, const Type t2) override; Type VisitType_(const FuncTypeNode* t1, const Type t2) override; // Type VisitType_(const TupleTypeNode* t1, const Type t2) override; - Type VisitType_(const TypeFunctionNode* s1, const Type t2) override; + Type VisitType_(const TypeRelationNode* s1, const Type t2) override; Type VisitType_(const TypeCallNode* s1, const Type t2) override; }; diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 2b6647a5807e..d9e2737225ec 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -116,22 +116,22 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->type_constraints << ")"; }); -TypeFunction TypeFunctionNode::make(std::string name, int num_args) { - std::shared_ptr n = std::make_shared(); +TypeRelation TypeRelationNode::make(std::string name, int num_args) { + std::shared_ptr n = std::make_shared(); n->name = std::move(name); n->num_args = std::move(num_args); - return TypeFunction(n); + return TypeRelation(n); } -TVM_REGISTER_API("relay._make.TypeFunction") +TVM_REGISTER_API("relay._make.TypeRelation") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TypeFunctionNode::make(args[0], args[1]); + *ret = TypeRelationNode::make(args[0], args[1]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const TypeFunctionNode *node, + .set_dispatch([](const TypeRelationNode *node, tvm::IRPrinter *p) { - p->stream << "TypeFunctionNode(" << node->name << ", " << node->num_args << ")"; + p->stream << "TypeRelationNode(" << node->name << ", " << node->num_args << ")"; }); TypeCall TypeCallNode::make(Type func, Array args) { diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index d1d3e01ed9a6..05e1cbd57b13 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -36,7 +36,9 @@ RELAY_REGISTER_UNARY_OP("log") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_func("Broadcast"); +.add_type_func("Log", [](const Array & t, int num_args) { + return t; +}); RELAY_REGISTER_UNARY_OP("exp") From 88dde8b34d84f821a168d8b6e5133e4e72e8a4a1 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 26 Aug 2018 21:54:04 -0700 Subject: [PATCH 43/88] Address comments from Friday --- .../tvm/relay/{compiler => }/environment.h | 27 +++++++++---------- include/tvm/relay/expr_functor.h | 3 ++- include/tvm/relay/ir.h | 20 -------------- .../tvm/relay/{compiler => pass}/alpha_eq.h | 9 ++++--- .../tvm/relay/{compiler => pass}/type_infer.h | 14 +++++----- python/tvm/relay/ir.py | 18 ------------- src/relay/{compiler => ir}/environment.cc | 2 +- src/relay/{compiler => pass}/alpha_eq.cc | 8 +++--- .../{compiler => pass}/incomplete_type.h | 8 +++--- src/relay/{compiler => pass}/resolve.cc | 10 +++---- src/relay/{compiler => pass}/resolve.h | 6 ++--- src/relay/{compiler => pass}/type_functor.h | 8 +++--- src/relay/{compiler => pass}/type_infer.cc | 12 ++++----- src/relay/{compiler => pass}/type_subst.cc | 12 ++++----- src/relay/{compiler => pass}/type_subst.h | 8 +++--- src/relay/{compiler => pass}/type_visitor.h | 0 src/relay/{compiler => pass}/unifier.cc | 22 ++++++++------- src/relay/{compiler => pass}/unifier.h | 10 +++---- 18 files changed, 80 insertions(+), 117 deletions(-) rename include/tvm/relay/{compiler => }/environment.h (83%) delete mode 100644 include/tvm/relay/ir.h rename include/tvm/relay/{compiler => pass}/alpha_eq.h (52%) rename include/tvm/relay/{compiler => pass}/type_infer.h (69%) delete mode 100644 python/tvm/relay/ir.py rename src/relay/{compiler => ir}/environment.cc (99%) rename src/relay/{compiler => pass}/alpha_eq.cc (97%) rename src/relay/{compiler => pass}/incomplete_type.h (82%) rename src/relay/{compiler => pass}/resolve.cc (92%) rename src/relay/{compiler => pass}/resolve.h (79%) rename src/relay/{compiler => pass}/type_functor.h (95%) rename src/relay/{compiler => pass}/type_infer.cc (98%) rename src/relay/{compiler => pass}/type_subst.cc (66%) rename src/relay/{compiler => pass}/type_subst.h (54%) rename src/relay/{compiler => pass}/type_visitor.h (100%) rename src/relay/{compiler => pass}/unifier.cc (96%) rename src/relay/{compiler => pass}/unifier.h (95%) diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/environment.h similarity index 83% rename from include/tvm/relay/compiler/environment.h rename to include/tvm/relay/environment.h index 2ec8ca8af933..ff8e596059b5 100644 --- a/include/tvm/relay/compiler/environment.h +++ b/include/tvm/relay/environment.h @@ -1,18 +1,17 @@ /*! * Copyright (c) 2018 by Contributors - * \file environment.h - * \brief The global environment containing + * \file tvm/relay/environment.h + * \brief The global environment, contains global state of Relay program. */ #ifndef TVM_RELAY_ENVIRONMENT_H_ #define TVM_RELAY_ENVIRONMENT_H_ #include #include -#include "../expr.h" -#include "../type.h" -#include "../op.h" -#include "../error.h" -// #include "tvm/relay/options.h" +#include "./expr.h" +#include "./type.h" +#include "./op.h" +#include "./error.h" // #include "tvm/relay/source_map.h" namespace tvm { @@ -38,10 +37,8 @@ struct Environment; class EnvironmentNode : public RelayNode { private: - /*! A map from string names to GlobalIds, ensures global uniqueness. */ + /*! \brief A map from string names to global variables ensures global uniqueness. */ tvm::Map global_map_; - tvm::Map type_func_map_; - // /*! \brief A map from file names to source fragments. */ // SourceMap source_map_ // /*! \brief A list of the errors reported during the current run. */ @@ -51,8 +48,6 @@ class EnvironmentNode : public RelayNode { /*! \brief A map from ids to all global functions. */ tvm::Map items; - // Options options; - EnvironmentNode() {} void VisitAttrs(tvm::AttrVisitor* v) final {} @@ -67,15 +62,17 @@ class EnvironmentNode : public RelayNode { GlobalVar GetGlobalVar(const std::string& str); - /*! \brief Lookup a global function by its name. */ + /*! \brief Lookup a global function by its variable. */ Function Lookup(const GlobalVar& id); + + /*! \brief Lookup a global function by its string name */ Function Lookup(const std::string & s); /*! \brief Add a source fragment to the environment. */ // FileId add_source(std::string file_name, std::string source); - void report_error(std::string msg, Span sp); - void display_errors(); + void ReportError(std::string msg, Span sp); + void DisplayErrors(); static constexpr const char* _type_key = "relay.Environment"; TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 2067b90bd364..e37a454eee41 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -9,7 +9,8 @@ #include #include -#include "ir.h" +#include "./expr.h" +#include "./op.h" namespace tvm { namespace relay { diff --git a/include/tvm/relay/ir.h b/include/tvm/relay/ir.h deleted file mode 100644 index 73c275cf1c98..000000000000 --- a/include/tvm/relay/ir.h +++ /dev/null @@ -1,20 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file tvm/relay/ir.h - * \brief The Relay intermediate representation's core data structures. - */ -#ifndef TVM_RELAY_IR_H_ -#define TVM_RELAY_IR_H_ - -#include "./base.h" -#include "./type.h" -#include "./expr.h" -#include "./op.h" - -// namespace tvm { -// namespace relay { - -// } // namespace relay -// } // namespace tvm - -#endif // TVM_RELAY_IR_H_ diff --git a/include/tvm/relay/compiler/alpha_eq.h b/include/tvm/relay/pass/alpha_eq.h similarity index 52% rename from include/tvm/relay/compiler/alpha_eq.h rename to include/tvm/relay/pass/alpha_eq.h index ba91afc21015..caa2f93c31a7 100644 --- a/include/tvm/relay/compiler/alpha_eq.h +++ b/include/tvm/relay/pass/alpha_eq.h @@ -1,18 +1,19 @@ /*! * Copyright (c) 2018 by Contributors * \file tvm/relay/alpha_eq.h - * \brief Check expressions & types for structural equivalence. + * \brief Check expressions and types for structural equivalence. */ #ifndef TVM_RELAY_ALPHA_EQ_H_ #define TVM_RELAY_ALPHA_EQ_H_ -#include "tvm/relay/ir.h" +#include "tvm/relay/type.h" +#include "tvm/relay/expr.h" namespace tvm { namespace relay { -bool alpha_eq(const Expr & e1, const Expr & e2); -bool alpha_eq(const Type & t1, const Type & t2); +bool AlphaEqual(const Expr & e1, const Expr & e2); +bool AlphaEqual(const Type & t1, const Type & t2); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/compiler/type_infer.h b/include/tvm/relay/pass/type_infer.h similarity index 69% rename from include/tvm/relay/compiler/type_infer.h rename to include/tvm/relay/pass/type_infer.h index c084fb7a109e..9a8ab2bc6a8b 100644 --- a/include/tvm/relay/compiler/type_infer.h +++ b/include/tvm/relay/pass/type_infer.h @@ -1,16 +1,16 @@ /*! * Copyright (c) 2018 by Contributors - * \file tvm/relay/type_infer.h + * \file tvm/relay/pass/type_infer.h * \brief Perform type inference and checking on Relay programs. * * The pass produces a new expression with its checked_type * field populated and incomplete types resolved. */ -#ifndef TVM_RELAY_COMPILER_TYPECHECKER_H_ -#define TVM_RELAY_COMPILER_TYPECHECKER_H_ +#ifndef TVM_RELAY_PASS__TYPECHECKER_H_ +#define TVM_RELAY_PASS__TYPECHECKER_H_ -#include "tvm/relay/ir.h" -#include "tvm/relay/compiler/environment.h" +#include "tvm/relay/expr.h" +#include "tvm/relay/environment.h" namespace tvm { namespace relay { @@ -19,7 +19,7 @@ namespace relay { * with unambigous type information filled in, as well as it's * checked type field populated with the result type. */ -Expr Infer(const Environment & env, const Expr & e); +Expr InferType(const Environment & env, const Expr & e); /*! \brief Ensures that an operator is well-formed with respect * to Relay's type system. @@ -28,4 +28,4 @@ Op CheckOp(const Environment & env, const Op & op); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_COMPILER_TYPECHECKER_H_ +#endif // TVM_RELAY_PASS_TYPECHECKER_H_ diff --git a/python/tvm/relay/ir.py b/python/tvm/relay/ir.py deleted file mode 100644 index a95f29abe6de..000000000000 --- a/python/tvm/relay/ir.py +++ /dev/null @@ -1,18 +0,0 @@ -from . import base -from . import type as ty -from . import expr - -# Base -register_relay_node = base.register_relay_node -NodeBase = base.NodeBase - -# Type -Type = ty.Type -TensorType = ty.Type -Kind = ty.Kind -TypeParam = ty.TypeParam -TypeConstraint = ty.TypeConstraint -FuncType = ty.FuncType -IncompleteType = ty.IncompleteType - -# Expr diff --git a/src/relay/compiler/environment.cc b/src/relay/ir/environment.cc similarity index 99% rename from src/relay/compiler/environment.cc rename to src/relay/ir/environment.cc index a1c6b31076e3..8c155e3bc1bd 100644 --- a/src/relay/compiler/environment.cc +++ b/src/relay/ir/environment.cc @@ -4,7 +4,7 @@ * \brief The global environment in Relay. */ #include -#include "tvm/relay/compiler/environment.h" +#include "tvm/relay/environment.h" // #include "tvm/relay/alpha_eq.h" // #include "tvm/relay/debug.h" // #include "tvm/relay/typeck/typechecker.h" diff --git a/src/relay/compiler/alpha_eq.cc b/src/relay/pass/alpha_eq.cc similarity index 97% rename from src/relay/compiler/alpha_eq.cc rename to src/relay/pass/alpha_eq.cc index d4f1d888fb69..5247bb5beaef 100644 --- a/src/relay/compiler/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -1,9 +1,9 @@ /*! * Copyright (c) 2018 by Contributors - * \file alpha_eq.cc + * \file src/tvm/relay/pass/alpha_eq.cc * \brief Compute the set of variables not bound in the expression. */ -#include "tvm/relay/compiler/alpha_eq.h" +#include "tvm/relay/pass/alpha_eq.h" #include "tvm/relay/expr_visitor.h" #include "./type_visitor.h" @@ -134,7 +134,7 @@ struct TypeAlphaEq : TypeVisitor { // } }; -bool alpha_eq(const Type &t1, const Type &t2) { +bool AlphaEqual(const Type &t1, const Type &t2) { TypeAlphaEq aeq; aeq.VisitType(t1, t2); return aeq.equal; @@ -277,7 +277,7 @@ TVM_REGISTER_API("relay._make._type_alpha_eq") .set_body([](TVMArgs args, TVMRetValue *ret) { Type t1 = args[0]; Type t2 = args[1]; - *ret = alpha_eq(t1, t2); + *ret = AlphaEqual(t1, t2); }); } // namespace relay diff --git a/src/relay/compiler/incomplete_type.h b/src/relay/pass/incomplete_type.h similarity index 82% rename from src/relay/compiler/incomplete_type.h rename to src/relay/pass/incomplete_type.h index f31a2efdf78d..3967b4e58657 100644 --- a/src/relay/compiler/incomplete_type.h +++ b/src/relay/pass/incomplete_type.h @@ -4,10 +4,10 @@ * \brief A way to defined arbitrary function signature with dispatch on types. */ -#ifndef TVM_RELAY_COMPILER_INCOMPLETE_TYPE_H -#define TVM_RELAY_COMPILER_INCOMPLETE_TYPE_H +#ifndef TVM_RELAY_PASS_INCOMPLETE_TYPE_H +#define TVM_RELAY_PASS_INCOMPLETE_TYPE_H -#include "tvm/relay/ir.h" +#include namespace tvm { namespace relay { @@ -37,4 +37,4 @@ RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_COMPILER_INCOMPLETE_TYPE_H +#endif // TVM_RELAY_PASS_INCOMPLETE_TYPE_H diff --git a/src/relay/compiler/resolve.cc b/src/relay/pass/resolve.cc similarity index 92% rename from src/relay/compiler/resolve.cc rename to src/relay/pass/resolve.cc index 236722b23387..e86368854060 100644 --- a/src/relay/compiler/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -1,18 +1,18 @@ /*! * Copyright (c) 2018 by Contributors - * \file unifier.cc - * \brief Data structures for type unification + * \file resolve.cc + * \brief Resolve incomplete types to complete types. */ +#include +#include #include "./resolve.h" #include "./type_visitor.h" -#include "tvm/relay/expr_visitor.h" -#include "tvm/relay/ir.h" namespace tvm { namespace relay { -// We should probably generalize the subst code. +// TODO(@jroesch): We should probably generalize the subst code. struct ResolveTypeType : TypeFVisitor { const TypeUnifier &unifier; diff --git a/src/relay/compiler/resolve.h b/src/relay/pass/resolve.h similarity index 79% rename from src/relay/compiler/resolve.h rename to src/relay/pass/resolve.h index b4e164df6287..5f6cc328a239 100644 --- a/src/relay/compiler/resolve.h +++ b/src/relay/pass/resolve.h @@ -1,13 +1,13 @@ /*! * Copyright (c) 2018 by Contributors - * \file tvm/relay/options.h - * \brief Global options for the Relay IR. + * \file tvm/relay/resolve.h + * \brief Resolve incomplete types to complete types. */ #ifndef TVM_RELAY_TYPECK_RESOLVE_H_ #define TVM_RELAY_TYPECK_RESOLVE_H_ #include -#include "tvm/relay/ir.h" +#include #include "./unifier.h" namespace tvm { diff --git a/src/relay/compiler/type_functor.h b/src/relay/pass/type_functor.h similarity index 95% rename from src/relay/compiler/type_functor.h rename to src/relay/pass/type_functor.h index 5de56837ca10..9adc1a08860e 100644 --- a/src/relay/compiler/type_functor.h +++ b/src/relay/pass/type_functor.h @@ -3,11 +3,11 @@ * \file type_functor.h * \brief A way to defined arbitrary function signature with dispatch on types. */ -#ifndef TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ -#define TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ +#ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_ +#define TVM_RELAY_PASS_TYPE_FUNCTOR_H_ #include -#include "tvm/relay/ir.h" +#include #include "./incomplete_type.h" namespace tvm { @@ -90,4 +90,4 @@ class TypeFunctor { } // namespace relay } // namespace tvm -#endif // TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ +#endif // TVM_RELAY_PASS_TYPE_FUNCTOR_H_ diff --git a/src/relay/compiler/type_infer.cc b/src/relay/pass/type_infer.cc similarity index 98% rename from src/relay/compiler/type_infer.cc rename to src/relay/pass/type_infer.cc index e2e5999e7341..d84b3f96d426 100644 --- a/src/relay/compiler/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -20,10 +20,10 @@ * constraints we will trigger an error. */ -#include "tvm/relay/logging.h" -#include "tvm/relay/compiler/type_infer.h" -#include "tvm/relay/error.h" -#include "tvm/relay/expr_functor.h" +#include +#include +#include +#include #include "./incomplete_type.h" #include "./unifier.h" #include "./resolve.h" @@ -335,7 +335,7 @@ class TypeInferencer : private ExprFunctor { // auto fresh_tid = // TypeParamNode::make(ty_param_node->name, ty_param_node->kind); // fn_type = - // type_subst(fn_type, GetRef(ty_param_node), fresh_tid); + // TypeSubst(fn_type, GetRef(ty_param_node), fresh_tid); // fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); // } // } @@ -360,7 +360,7 @@ class TypeInferencer : private ExprFunctor { } Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); - inst_ty = type_subst(fn_ty, subst_map); + inst_ty = TypeSubst(fn_ty, subst_map); // if (!check_kind(t)) { // this->fatal_error("Kind rules broken when instantiating type diff --git a/src/relay/compiler/type_subst.cc b/src/relay/pass/type_subst.cc similarity index 66% rename from src/relay/compiler/type_subst.cc rename to src/relay/pass/type_subst.cc index 6650f59bad51..91713976bcaa 100644 --- a/src/relay/compiler/type_subst.cc +++ b/src/relay/pass/type_subst.cc @@ -9,10 +9,10 @@ namespace tvm { namespace relay { -struct TypeSubst : TypeFVisitor { +struct TypeSubstV : TypeFVisitor { tvm::Map subst_map; - explicit TypeSubst(tvm::Map subst_map) + explicit TypeSubstV(tvm::Map subst_map) : subst_map(subst_map) {} Type VisitType_(const TypeParamNode *op) override { @@ -25,13 +25,13 @@ struct TypeSubst : TypeFVisitor { } }; -Type type_subst(const Type &type, const TypeParam &target, const Type &subst) { - TypeSubst ty_sub({ {target, subst} }); +Type TypeSubst(const Type &type, const TypeParam &target, const Type &subst) { + TypeSubstV ty_sub({ {target, subst} }); return ty_sub.VisitType(type); } -Type type_subst(const Type &type, tvm::Map subst_map) { - TypeSubst ty_sub(subst_map); +Type TypeSubst(const Type &type, tvm::Map subst_map) { + TypeSubstV ty_sub(subst_map); return ty_sub.VisitType(type); } diff --git a/src/relay/compiler/type_subst.h b/src/relay/pass/type_subst.h similarity index 54% rename from src/relay/compiler/type_subst.h rename to src/relay/pass/type_subst.h index 0bf0de5a4b85..3c248fdce3b7 100644 --- a/src/relay/compiler/type_subst.h +++ b/src/relay/pass/type_subst.h @@ -1,18 +1,18 @@ /*! * Copyright (c) 2018 by Contributors * \file typeck/type_subst.h - * \brief Utility function for substituting types + * \brief Utility functions for substituting types. */ #ifndef TVM_RELAY_TYPECK_TYPE_SUBST_H_ #define TVM_RELAY_TYPECK_TYPE_SUBST_H_ -#include "tvm/relay/ir.h" +#include namespace tvm { namespace relay { -Type type_subst(const Type & type, const TypeParam & target, const Type & subst); -Type type_subst(const Type &type, tvm::Map subst_map); +Type TypeSubst(const Type & type, const TypeParam & target, const Type & subst); +Type TypeSubst(const Type &type, tvm::Map subst_map); } // namespace relay } // namespace tvm diff --git a/src/relay/compiler/type_visitor.h b/src/relay/pass/type_visitor.h similarity index 100% rename from src/relay/compiler/type_visitor.h rename to src/relay/pass/type_visitor.h diff --git a/src/relay/compiler/unifier.cc b/src/relay/pass/unifier.cc similarity index 96% rename from src/relay/compiler/unifier.cc rename to src/relay/pass/unifier.cc index 2f728a104530..c6a4e7dfba6d 100644 --- a/src/relay/compiler/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -1,12 +1,14 @@ /*! * Copyright (c) 2018 by Contributors - * \file unifier.cc - * \brief Data structures for type unification + * \file tvm/src/relay/pass/unifier.cc + * \brief The type unifier which solves a system of equations between + * incomplete types. */ -#include "tvm/relay/ir.h" -#include "tvm/relay/logging.h" -#include "tvm/relay/compiler/alpha_eq.h" +#include +#include +#include +#include #include "./unifier.h" #include "./type_visitor.h" //#include "./type_subst.h" @@ -32,8 +34,8 @@ void UnionFindNode::debug() { } } -void UnionFindNode::assertAlphaEq(const Type & l, const Type & r) { - if (!alpha_eq(l, r)) { +void UnionFindNode::AssertAlphaEqual(const Type & l, const Type & r) { + if (!AlphaEqual(l, r)) { std::stringstream ss; ss << "Incompatible parent types in UF:" << l << " and " << r; throw UnionFindError(ss.str()); @@ -71,7 +73,7 @@ void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { } // if both parents are not type vars themselves, check alpha-equality - assertAlphaEq(parent1, parent2); + AssertAlphaEqual(parent1, parent2); return; } @@ -83,7 +85,7 @@ void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { return; } - assertAlphaEq(parent1, t); + AssertAlphaEqual(parent1, t); } Type UnionFindNode::find(const IncompleteType &v) { @@ -274,7 +276,7 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { if (const TensorTypeNode *ttn2 = rt2.as()) { TensorType tt2 = GetRef(ttn2); - if (!alpha_eq(tt1, tt2)) { + if (!AlphaEqual(tt1, tt2)) { throw UnificationError("dtypes do not match"); } diff --git a/src/relay/compiler/unifier.h b/src/relay/pass/unifier.h similarity index 95% rename from src/relay/compiler/unifier.h rename to src/relay/pass/unifier.h index 40583b16a55a..aecc428cb6a9 100644 --- a/src/relay/compiler/unifier.h +++ b/src/relay/pass/unifier.h @@ -1,15 +1,15 @@ /*! * Copyright (c) 2018 by Contributors - * \file unifier.h + * \file include/tvm/relay/pass/unifier.h * \brief The type unifier which solves a system of equations between * incomplete types. */ -#ifndef TVM_RELAY_COMPILER_UNIFIER_H_ -#define TVM_RELAY_COMPILER_UNIFIER_H_ +#ifndef TVM_RELAY_PASS_UNIFIER_H_ +#define TVM_RELAY_PASS_UNIFIER_H_ #include +#include #include "./type_functor.h" -#include "tvm/relay/ir.h" namespace tvm { namespace relay { @@ -50,7 +50,7 @@ class UnionFindNode : public Node { void debug(); - void assertAlphaEq(const Type& l, const Type& r); + void AssertAlphaEqual(const Type& l, const Type& r); static constexpr const char* _type_key = "relay.UnionFind"; TVM_DECLARE_NODE_TYPE_INFO(UnionFindNode, Node); From 4d6f6e3b43109746448495368cd5593ace993eb9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 26 Aug 2018 21:55:56 -0700 Subject: [PATCH 44/88] Repair tests --- python/tvm/relay/op/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 3d87d78fe633..d54f47e25197 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -9,7 +9,3 @@ from . import _tensor from ..expr import Expr from ..base import register_relay_node - -@register_relay_node -class Op(Expr): - pass From 0aedf89b2b0037e8d3b9019211581384e255cb6f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 Aug 2018 16:34:55 -0700 Subject: [PATCH 45/88] Work on type relation --- include/tvm/relay/op.h | 19 +++++++- python/tvm/relay/op/tensor.py | 17 +++++++ src/relay/op/tensor/elemwise.cc | 25 ++++++++--- src/relay/op/type_relations.cc | 45 +++++++++++++++++++ src/relay/op/type_relations.h | 22 +++++++++ src/relay/pass/alpha_eq.cc | 44 ++++++++++-------- src/relay/pass/type_infer.cc | 32 +++++++------ src/relay/pass/type_visitor.h | 4 ++ ...ecker.py => test_tyck_eval_integration.py} | 16 ++++++- 9 files changed, 184 insertions(+), 40 deletions(-) create mode 100644 src/relay/op/type_relations.cc create mode 100644 src/relay/op/type_relations.h rename tests/python/relay/{test_typechecker.py => test_tyck_eval_integration.py} (65%) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 630a231ebb54..c91955460f82 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -345,9 +345,26 @@ inline OpRegistry& OpRegistry::add_argument(const std::string &name, inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name, TypeRelationFn type_fn) { auto type_func = TypeRelationNode::make(type_func_name, 0); + + std::vector type_params; + std::vector arg_types; + // TODO (@jroesch: revise type generation strategy + int i = 0; for (auto arg : get()->arguments) { - std::cout << arg << std::endl; + std::string name = "t"; + name += std::to_string(i++); + auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType); + type_params.push_back(param); + arg_types.push_back(param); } + + + auto type_result = TypeCallNode::make(type_func, arg_types); + + auto func_type = FuncTypeNode::make(arg_types, type_result, type_params, {}); + + get()->op_type = func_type; + return *this; } diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 7155db3a4cd5..aa9ce6bf42e9 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -58,3 +58,20 @@ def sqrt(data): The computed result. """ return _make.sqrt(data) + +def add(lhs, rhs): + """Take sqrt of data. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.add(lhs, rhs) diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index 05e1cbd57b13..700e9185ccba 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -5,6 +5,7 @@ */ #include #include +#include "../type_relations.h" namespace tvm { namespace relay { @@ -36,9 +37,7 @@ RELAY_REGISTER_UNARY_OP("log") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_func("Log", [](const Array & t, int num_args) { - return t; -}); +.add_type_func("Log", IdentityRel); RELAY_REGISTER_UNARY_OP("exp") @@ -48,7 +47,8 @@ RELAY_REGISTER_UNARY_OP("exp") \exp(x) )code" TVM_ADD_FILELINE) -.set_support_level(1); +.set_support_level(1) +.add_type_func("Exp", IdentityRel); RELAY_REGISTER_UNARY_OP("sqrt") @@ -58,7 +58,22 @@ RELAY_REGISTER_UNARY_OP("sqrt") sqrt(x) )code" TVM_ADD_FILELINE) -.set_support_level(1); +.set_support_level(1) +.add_type_func("Sqrt", IdentityRel); + +// Addition +TVM_REGISTER_API("relay.op._make.add") + .set_body_typed([](Expr lhs, Expr rhs) { + static const Op& op = Op::Get("add"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + }); + +RELAY_REGISTER_OP("add") + .set_num_inputs(2) + .add_argument("lhs", "Tensor", "The left hand side tensor.") + .add_argument("rhs", "Tensor", "The right hand side tensor.") + .set_support_level(1) + .add_type_func("Broadcast", BroadcastRel); } // namespace relayv } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc new file mode 100644 index 000000000000..a5ba1dc14b5f --- /dev/null +++ b/src/relay/op/type_relations.cc @@ -0,0 +1,45 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_relations.cc + * \brief A set of utilities and common functionality + * for type relations. + */ +#include +#include +#include "../pass/incomplete_type.h" + +namespace tvm { +namespace relay { + +TensorType as_ttype(const Type & t) { + if (auto tt_node = t.as()) { + return GetRef(tt_node); + } else { + return TensorType(nullptr); + } +} + +Array IdentityRel(const Array & types, int num_args) { + CHECK(types.size() == 1); + auto t1 = as_ttype(types[0]); + if (t1 && types[1].as()) { + return {t1, t1}; + } else { + return types; + } +} + +Array BroadcastRel(const Array & types, int num_args) { + std::cout << "Inside of Broadcast" << std::endl; + CHECK(types.size() == 0); + if (auto t1 = as_ttype(types[0])) { + if (auto t2 = as_ttype(types[1])) { + return types; + } + } + return types; +} + + +} // namespace relayv +} // namespace tvm diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h new file mode 100644 index 000000000000..f2c4876705b6 --- /dev/null +++ b/src/relay/op/type_relations.h @@ -0,0 +1,22 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/op/type_relations.h + * \brief A set of utilities and common functionality + * for type relations. + */ +#ifndef TVM_RELAY_TYPECK_RESOLVE_H_ +#define TVM_RELAY_TYPECK_RESOLVE_H_ + +#include +#include + +namespace tvm { +namespace relay { + +Array IdentityRel(const Array & types, int num_args); +Array BroadcastRel(const Array & types, int num_args); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TYPECK_RESOLVE_H_ diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 5247bb5beaef..555d4f2db99d 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -92,8 +92,14 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const TypeRelationNode *op, const Type &t2) override { + void VisitType_(const TypeRelationNode *tr1, const Type &t2) override { + if (const TypeRelationNode *tr2 = t2.as()) { + equal = tr1 == tr2; + } else { + equal = false; + } } + // void VisitType_(const TupleTypeNode *op, const Type &t2) override { // if (const TupleTypeNode *pt = t2.as()) { // if (op->fields.size() != pt->fields.size()) { @@ -112,26 +118,26 @@ struct TypeAlphaEq : TypeVisitor { // } // } -// void VisitType_(const TypeCallNode *tyn1, const Type &t2) override { -// TypeCall tycall = GetRef(tyn1); -// if (const TypeCallNode *tyn2 = t2.as()) { -// if (tycall->func != tyn2->func) { -// equal = false; -// return; -// } + void VisitType_(const TypeCallNode *tyn1, const Type &t2) override { + TypeCall tycall = GetRef(tyn1); + if (const TypeCallNode *tyn2 = t2.as()) { + if (tycall->func != tyn2->func) { + equal = false; + return; + } -// if (tycall->args.size() != tyn2->args.size()) { -// equal = false; -// return; -// } + if (tycall->args.size() != tyn2->args.size()) { + equal = false; + return; + } -// for (size_t i = 0U; i < tycall->args.size(); i++) { -// this->VisitType(tycall->args[i], tyn2->args[i]); -// } -// } else { -// equal = false; -// } -// } + for (size_t i = 0U; i < tycall->args.size(); i++) { + this->VisitType(tycall->args[i], tyn2->args[i]); + } + } else { + equal = false; + } + } }; bool AlphaEqual(const Type &t1, const Type &t2) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index d84b3f96d426..b9cfd5837c4b 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -28,17 +28,8 @@ #include "./unifier.h" #include "./resolve.h" #include "./type_subst.h" -// #include "tvm/relay/alpha_eq.h" -// #include "tvm/relay/debug.h" -// #include "tvm/relay/first_order_reverse_ad.h" -// #include "tvm/relay/free_type_vars.h" -// #include "tvm/relay/gen_fresh.h" -// #include "tvm/relay/ir.h" -// #include "tvm/relay/pretty_printer.h" -// #include "tvm/relay/reverse_ad.h" -// #include "tvm/relay/type_visitor.h" +#include "./type_visitor.h" // #include "tvm/relay/typeck/kindchecker.h" -// #include "tvm/relay/typeck/shape_evaluator.h" namespace tvm { namespace relay { @@ -68,6 +59,12 @@ struct TypeContext { }; }; +struct TypeNormalizer : TypeFVisitor { + TypeUnifier unifier; + TypeNormalizer(const TypeUnifier & unifier) : unifier(unifier) {} + // Type VisitType_( +}; + struct CheckedExpr { Expr expr; Type type; @@ -98,6 +95,8 @@ class TypeInferencer : private ExprFunctor { FuncType instantiate(FuncType fn_ty, tvm::Array &ty_args); + Type Normalize(const Type & t); + void report_error(const std::string & msg, Span sp); [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); @@ -129,11 +128,17 @@ class TypeInferencer : private ExprFunctor { this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); } + Type TypeInferencer::Normalize(const Type & t) { + auto nt = this->resolve(t); + auto normalizer = TypeNormalizer(this->unifier); + return normalizer.VisitType(nt); + } + CheckedExpr TypeInferencer::Infer(const Expr &expr) { RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl; CheckedExpr checked_expr = this->VisitExpr(expr); RELAY_LOG(INFO) << "TypeInferencer::Check type=" << checked_expr.type << std::endl; - Type final_type = this->unifier->subst(checked_expr.type); + Type final_type = Normalize(checked_expr.type); RELAY_LOG(INFO) << "TypeInferencer::Check type_after_subst=" << final_type << std::endl; checked_expr.expr->checked_type_ = final_type; return checked_expr; @@ -498,8 +503,9 @@ class TypeInferencer : private ExprFunctor { ifn->cond->span); } - CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op) { - return { GetRef(op), FuncTypeNode::make({}, TensorTypeNode::Int(32), {}, {} )}; + CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) { + auto op = GetRef(op_node); + return { op, op->op_type }; } Type TypeInferencer::resolve(const Type &t) { diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index c98ff3ab8958..68dba76644c3 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -85,6 +85,10 @@ struct TypeFVisitor : TypeFunctor { // return TupleTypeNode::make(new_fields); // } + Type VisitType_(const TypeRelationNode* op) override { + return GetRef(op); + } + Type VisitType_(const TypeCallNode* op) override { auto func = this->VisitType(op->func); std::vector new_args; diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_tyck_eval_integration.py similarity index 65% rename from tests/python/relay/test_typechecker.py rename to tests/python/relay/test_tyck_eval_integration.py index d111bba9dfbf..d96682fbfda4 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -4,11 +4,12 @@ from tvm.relay.type_infer import check_expr from tvm.relay.ir_builder import IRBuilder, float_type, func_type from tvm.relay.env import Environment -from tvm.relay.op import log +from tvm.relay.op import log, add def has_type(expr, typ): env = Environment({}) checked_expr = check_expr(env, expr) + import pdb; pdb.set_trace() return checked_expr.checked_type() == typ def test_monomorphic_let(): @@ -30,6 +31,17 @@ def test_single_op(): b.ret(t1) assert has_type(func.to_func(), func_type([float_type()], float_type())) +def test_dual_op(): + "Program: fn (x : float32) { let t1 = f(x); let t2 = g(t1, x); t1 }" + b = IRBuilder() + with b.function(('x', float_type())) as func: + x, = func.param_ids() + t1 = b.let('t1', log(x)) + t2 = b.let('t2', add(t1, x)) + b.ret(t2) + assert has_type(func.to_func(), func_type([float_type()], float_type())) + if __name__ == "__main__": - test_monomorphic_let() + # test_monomorphic_let() test_single_op() + test_dual_op() From 744dc953ae0d99abf20e2fe53174726e7f68d60b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 Aug 2018 21:22:32 -0700 Subject: [PATCH 46/88] Fix find and replace bug --- src/codegen/spirv/ir_builder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index 87987dbf08e9..41cb48c5854b 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -41,7 +41,7 @@ void IRBuilder::InitPreDefs() { t_void_.id = id_counter_++; ib_.Begin(spv::OpTypeVoid).Add(t_void_).Commit(&global_); t_void_func_.id = id_counter_++; - ib_.Begin(spv::OpTypeRelation) + ib_.Begin(spv::OpTypeFunction) .AddSeq(t_void_func_, t_void_).Commit(&global_); } From d70e870e53f0beb1b63ab569db1445c911138fab Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 11:54:35 -0700 Subject: [PATCH 47/88] Add normalization for type relations --- include/tvm/relay/base.h | 17 +++--- include/tvm/relay/expr_visitor.h | 2 +- include/tvm/relay/op.h | 59 +++++++++++++------ include/tvm/relay/pass/alpha_eq.h | 4 +- include/tvm/relay/type.h | 5 +- src/relay/ir/expr.cc | 35 +++++++---- src/relay/ir/type.cc | 24 ++++---- src/relay/op/tensor/elemwise.cc | 10 ++++ src/relay/op/type_relations.cc | 2 +- src/relay/pass/type_infer.cc | 48 +++++++++------ .../relay/test_tyck_eval_integration.py | 2 +- 11 files changed, 130 insertions(+), 78 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 3b31aae52617..f25d6e6532df 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -49,15 +49,16 @@ using NodeEqual = ::tvm::NodeEqual; * \param NodeName The internal contrainer name. * \param NodeRefBase The base type. */ -#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ - class TypeName : public NodeRefBase { \ - public: \ - TypeName() {} \ +#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ + class TypeName : public NodeRefBase { \ + public: \ + TypeName() {} \ explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefBase(n) {} \ - const NodeName* operator->() const { \ - return static_cast(node_.get()); \ - } \ - using ContainerType = NodeName; \ + const NodeName* operator->() const { \ + return static_cast(node_.get()); \ + } \ + operator bool() { return this->defined(); } \ + using ContainerType = NodeName; \ }; diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index d1e8a99dc374..8803aa5ae48f 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -7,7 +7,7 @@ #ifndef TVM_RELAY_EXPR_VISITOR_H_ #define TVM_RELAY_EXPR_VISITOR_H_ -#include "tvm/relay/expr_functor.h" +#include namespace tvm { namespace relay { diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index c91955460f82..0e5483174c53 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -155,6 +155,15 @@ class OpRegistry { */ inline OpRegistry& add_type_func(const std::string & type_func_name, TypeRelationFn type_fn); + /*! + * \brief Attach the type function corresponding to the return type. + * \param ty_func The type function to register for the return type. + * \return reference to self. + */ + inline OpRegistry& add_type_func( + const std::string & type_func_name, + std::function(const Array &, int)> type_fn); + /*! * \brief Set the type key of attributes. * \param type_key The type of of the attrs field.x @@ -343,30 +352,44 @@ inline OpRegistry& OpRegistry::add_argument(const std::string &name, return *this; } - inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name, TypeRelationFn type_fn) { - auto type_func = TypeRelationNode::make(type_func_name, 0); +inline OpRegistry& OpRegistry::add_type_func( + const std::string & type_func_name, + std::function(const Array &, int)> type_fn) { + auto pfunc = runtime::TypedPackedFunc(const Array &, int)>(type_fn); + return add_type_func(type_func_name, pfunc); +} - std::vector type_params; - std::vector arg_types; - // TODO (@jroesch: revise type generation strategy - int i = 0; - for (auto arg : get()->arguments) { - std::string name = "t"; - name += std::to_string(i++); - auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType); - type_params.push_back(param); - arg_types.push_back(param); - } +inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name, TypeRelationFn type_fn) { + auto type_func = TypeRelationNode::make(type_func_name, 0, type_fn); + std::vector type_params; + std::vector arg_types; - auto type_result = TypeCallNode::make(type_func, arg_types); + // Add inputs. + int i = 0; + for (auto arg : get()->arguments) { + std::string name = "in"; + name += std::to_string(i++); + auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType); + type_params.push_back(param); + arg_types.push_back(param); + } + + auto ty_call_args = Array(arg_types); + + // Add output type. + auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType); + type_params.push_back(out_param); + ty_call_args.push_back(out_param); - auto func_type = FuncTypeNode::make(arg_types, type_result, type_params, {}); + auto type_result = TypeCallNode::make(type_func, ty_call_args); - get()->op_type = func_type; + auto func_type = FuncTypeNode::make(arg_types, type_result, type_params, {}); - return *this; - } + get()->op_type = func_type; + + return *this; +} inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) get()->num_inputs = n; diff --git a/include/tvm/relay/pass/alpha_eq.h b/include/tvm/relay/pass/alpha_eq.h index caa2f93c31a7..9f3c2138a440 100644 --- a/include/tvm/relay/pass/alpha_eq.h +++ b/include/tvm/relay/pass/alpha_eq.h @@ -6,8 +6,8 @@ #ifndef TVM_RELAY_ALPHA_EQ_H_ #define TVM_RELAY_ALPHA_EQ_H_ -#include "tvm/relay/type.h" -#include "tvm/relay/expr.h" +#include "../type.h" +#include "../expr.h" namespace tvm { namespace relay { diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 68ed411a23ed..498053f4f9bb 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -210,7 +210,7 @@ class FuncTypeNode : public TypeNode { RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); -using TypeRelationFn = std::function(const Array&, int)>; +using TypeRelationFn = runtime::TypedPackedFunc(const Array&, int)>; /*! * \brief Opaque type relation, is an input-output relation on types. @@ -239,7 +239,7 @@ class TypeRelationNode : public RelayNode { v->Visit("num_args", &num_args); } - TVM_DLL static TypeRelation make(std::string name, int num_args); + TVM_DLL static TypeRelation make(std::string name, int num_args, TypeRelationFn func_); static constexpr const char* _type_key = "relay.TypeRelation"; TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, RelayNode); @@ -258,6 +258,7 @@ class TypeCallNode : public TypeNode { public: /*! \brief The type function to be called. */ Type func; + /*! \brief The type arguments to the type function. */ tvm::Array args; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 3a3ef1b52604..2b235e8b01ad 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -1,10 +1,10 @@ /*! * Copyright (c) 2018 by Contributors - * \file expr.cc + * \file src/tvm/ir/expr.cc * \brief The expression AST nodes of Relay. */ -#include "tvm/relay/expr.h" -#include "tvm/ir_functor.h" +#include +#include namespace tvm { namespace relay { @@ -29,6 +29,19 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "ConstantNode(TODO)"; }); +TensorType ConstantNode::tensor_type() const { + auto dl_dtype = data->dtype; + auto dtype = HalideIR::Type(static_cast(dl_dtype.code), + dl_dtype.bits, dl_dtype.lanes); + + Array shape; + for (int i = 0; i < data->ndim; i++) { + shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), data->shape[i])); + } + + return TensorTypeNode::make(shape, dtype); +} + Tuple TupleNode::make(tvm::Array fields) { std::shared_ptr n = std::make_shared(); n->fields = std::move(fields); @@ -114,11 +127,8 @@ TVM_REGISTER_API("relay._make.Function") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FunctionNode *node, tvm::IRPrinter *p) { - p->stream << "FunctionNode(" << - node->params << ", " << - node->ret_type << ", " << - node->body << ", " << - node->type_params << ")"; + p->stream << "FunctionNode(" << node->params << ", " << node->ret_type + << ", " << node->body << ", " << node->type_params << ")"; }); Call CallNode::make(Expr op, Array args, Attrs attrs, @@ -158,7 +168,8 @@ TVM_REGISTER_API("relay._make.Let") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const LetNode *node, tvm::IRPrinter *p) { - p->stream << "LetNode(" << node->var << node->value << node->body << node->value_type << ")"; + p->stream << "LetNode(" << node->var << node->value << node->body + << node->value_type << ")"; }); If IfNode::make(Expr cond, Expr true_value, Expr false_value) { @@ -175,10 +186,8 @@ TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) { TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const IfNode *node, tvm::IRPrinter *p) { - p->stream << "IfNode(" << - node->cond << ", " << - node->true_value << - node->false_value << ")"; + p->stream << "IfNode(" << node->cond << ", " << node->true_value + << node->false_value << ")"; }); } // namespace relay diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index d9e2737225ec..abed09a69d7b 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -1,11 +1,10 @@ /*! * Copyright (c) 2018 by Contributors - * \file type.cc + * \file src/tvm/ir/type.cc * \brief The type system AST nodes of Relay. */ -#include "tvm/relay/type.h" -#include "tvm/ir_functor.h" - +#include +#include namespace tvm { namespace relay { @@ -42,7 +41,6 @@ TVM_REGISTER_API("relay._make.TensorType") *ret = TensorTypeNode::make(shape, args[1]); }); - TVM_REGISTER_API("relay._make.IntType") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = TensorTypeNode::Int(args[0], args[1]); @@ -91,10 +89,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->kind << ")"; }); - FuncType FuncTypeNode::make(tvm::Array arg_types, Type ret_type, - tvm::Array type_params, - tvm::Array type_constraints) { + tvm::Array type_params, + tvm::Array type_constraints) { std::shared_ptr n = std::make_shared(); n->arg_types = std::move(arg_types); n->ret_type = std::move(ret_type); @@ -116,22 +113,24 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->type_constraints << ")"; }); -TypeRelation TypeRelationNode::make(std::string name, int num_args) { +TypeRelation TypeRelationNode::make(std::string name, int num_args, TypeRelationFn func) { std::shared_ptr n = std::make_shared(); n->name = std::move(name); n->num_args = std::move(num_args); + n->func_ = std::move(func); return TypeRelation(n); } TVM_REGISTER_API("relay._make.TypeRelation") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TypeRelationNode::make(args[0], args[1]); + *ret = TypeRelationNode::make(args[0], args[1], args[2]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TypeRelationNode *node, - tvm::IRPrinter *p) { - p->stream << "TypeRelationNode(" << node->name << ", " << node->num_args << ")"; + tvm::IRPrinter *p) { + p->stream << "TypeRelationNode(" << node->name << ", " << node->num_args + << ")"; }); TypeCall TypeCallNode::make(Type func, Array args) { @@ -152,6 +151,5 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; }); - } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index 700e9185ccba..cd90705c6476 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -39,6 +39,9 @@ RELAY_REGISTER_UNARY_OP("log") .set_support_level(1) .add_type_func("Log", IdentityRel); +// data : Tensor[shape, dtype] +// result: Tensor[shape, dtype] + RELAY_REGISTER_UNARY_OP("exp") .describe(R"code(Returns the exp input array, computed element-wise. @@ -75,5 +78,12 @@ RELAY_REGISTER_OP("add") .set_support_level(1) .add_type_func("Broadcast", BroadcastRel); + // def broadcast(s1, s2): + // ... + // + // input1: Tensor[dtype, s1] + // input2: Tensor[dtype, s2] + // output: Tensor[dtype, broadcast(s1, s2)] + } // namespace relayv } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index a5ba1dc14b5f..68fe2c51a365 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -20,7 +20,7 @@ TensorType as_ttype(const Type & t) { } Array IdentityRel(const Array & types, int num_args) { - CHECK(types.size() == 1); + CHECK(types.size() == 2); auto t1 = as_ttype(types[0]); if (t1 && types[1].as()) { return {t1, t1}; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index b9cfd5837c4b..72fa9cb14bdc 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -62,7 +62,34 @@ struct TypeContext { struct TypeNormalizer : TypeFVisitor { TypeUnifier unifier; TypeNormalizer(const TypeUnifier & unifier) : unifier(unifier) {} - // Type VisitType_( + + Type VisitType_(const TypeCallNode * ty_call_node) { + auto ty_call = GetRef(ty_call_node); + + auto all_concrete = true; + for (auto arg : ty_call->args) { + all_concrete = all_concrete && !arg.as(); + } + + if (all_concrete) { + return ty_call->args[ty_call->args.size() - 1]; + } else { + if (auto ty_rel_node = ty_call->func.as()) { + // NB: we substract 1 for the output argument. + auto new_args = ty_rel_node->func_(ty_call->args, ty_call->args.size() - 1); + CHECK(new_args.size() == ty_call->args.size()); + tvm::Array final_args; + + for (int i = 0; i < new_args.size(); i++) { + final_args.push_back(unifier->unify(ty_call->args[i], new_args[i])); + } + + return TypeCallNode::make(ty_call->func, final_args); + } else { + CHECK(false); + } + } + } }; struct CheckedExpr { @@ -167,26 +194,9 @@ class TypeInferencer : private ExprFunctor { } CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode *const_node) { - auto array = const_node->data; - // array->t - // first pass - return { - GetRef(const_node), - TensorTypeNode::make({}, HalideIR::Float(32, 1)) }; + return { GetRef(const_node), const_node->tensor_type() }; } - // Type TypeInferencer::VisitExpr_(const OpIdNode *op) { - // OpId id = GetRef(op); - // Item item = this->env->lookup(id); - - // if (const OpNode *pn = item.as()) { - // Op prim = GetRef(pn); - // return prim->type; - // } else { - // this->fatal_error("internal error in InstrinsicId case", op->span); - // } - // } - CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { // Tuple pl = GetRef(op); diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index d96682fbfda4..e94158cd44e2 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -13,7 +13,7 @@ def has_type(expr, typ): return checked_expr.checked_type() == typ def test_monomorphic_let(): - "Program: let x = 1; x" + "Program: let x = 1; return x" b = IRBuilder() x = b.let('x', 1, value_type=float_type()) b.ret(x) From fd430ea4ea87425732dd0dee0f82cefcd9fa3201 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 15:09:39 -0700 Subject: [PATCH 48/88] Iterating on Broadcast --- include/tvm/relay/type.h | 28 ++++++++- python/tvm/relay/ir_builder.py | 5 +- python/tvm/relay/type.py | 6 +- src/relay/ir/type.cc | 18 ++++++ src/relay/op/type_relations.cc | 58 ++++++++++++++++++- src/relay/pass/type_infer.cc | 16 +++-- .../relay/test_tyck_eval_integration.py | 19 +++--- 7 files changed, 129 insertions(+), 21 deletions(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 498053f4f9bb..a6c801c382de 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -269,8 +269,6 @@ class TypeCallNode : public TypeNode { v->Visit("args", &args); } - Type eval() const; - TVM_DLL static TypeCall make(Type func, tvm::Array args); static constexpr const char* _type_key = "relay.TypeCall"; @@ -279,6 +277,32 @@ class TypeCallNode : public TypeNode { RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); +/*! + * \brief The type of tuple values. + */ +class TupleType; +/*! + * \brief TupleType container. + */ +class TupleTypeNode : public TypeNode { + public: + /*! \brief The type of each field in the tuple. */ + tvm::Array fields; + + TupleTypeNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("fields", &fields); + } + + TVM_DLL static TupleType make(tvm::Array fields); + + static constexpr const char* _type_key = "relay.TypeTuple"; + TVM_DECLARE_NODE_TYPE_INFO(TypeTupleNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); + // The following fields contains advanced typing // Only keep the class name and reserved for future usage. class GenericTensorType; diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 8bd225bd4de1..a9cb02a19025 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -1,7 +1,7 @@ from typing import Any import numpy as np import tvm -from .type import FloatType, IntType, BoolType, UIntType, FuncType +from .type import FloatType, IntType, BoolType, UIntType, FuncType, TensorType from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function from . import op as _op @@ -167,5 +167,8 @@ def float_type(bits=32, lanes=1): def bool_type(lanes=1): return BoolType(lanes) +def tensor_type(*shape, dtype='float32'): + return TensorType(tvm.convert(shape), dtype) + def func_type(args, ret_type, type_params=[], type_constraints=[]): return FuncType(args, ret_type, type_params, type_constraints) diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index c7b8964c20e8..c9a96de4889d 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -27,12 +27,12 @@ def same_as(self, other) -> bool: class TensorType(Type): """A concrete TensorType in Relay, see tvm/relay/type.h for more details. """ - dtype: str shape: List[expr.Expr] + dtype: str span: Span - def __init__(self, dtype: str, shape: List[expr.Expr]) -> None: - self.__init_handle_by_constructor__(_make.TensorType,dtype, shape) + def __init__(self, shape: List[expr.Expr], dtype: str) -> None: + self.__init_handle_by_constructor__(_make.TensorType, shape, dtype) class Kind(IntEnum): """The kind of a type parameter, represents a variable shape, diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index abed09a69d7b..1faa9ede8638 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -151,5 +151,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; }); +TypeCall TupleTypeNode::make(Array fields) { + std::shared_ptr n = std::make_shared(); + n->fields = std::move(fields); + return TupleType(n); +} + +TVM_REGISTER_API("relay._make.TupleType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TupleTypeNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TupleTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "TupleTypeNode(" << node->fields << ")"; + }); + + } // namespace relay } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 68fe2c51a365..883b8ecc946d 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -4,6 +4,7 @@ * \brief A set of utilities and common functionality * for type relations. */ +#include #include #include #include "../pass/incomplete_type.h" @@ -19,6 +20,13 @@ TensorType as_ttype(const Type & t) { } } +// TODO(@jroesch) what size value do we extract? +int to_int(const tvm::Expr & e) { + auto imm = e.as(); + CHECK(imm); + return imm->value; +} + Array IdentityRel(const Array & types, int num_args) { CHECK(types.size() == 2); auto t1 = as_ttype(types[0]); @@ -29,12 +37,56 @@ Array IdentityRel(const Array & types, int num_args) { } } +static Type ConcreteBroadcast(const TensorType & t1, const TensorType & t2) { + RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2 << std::endl; + auto sh1 = t1->shape; + auto sh2 = t2->shape; + RELAY_LOG(INFO) << "ConcreteBroadcast: sh1=" << sh1 << " sh2=" << sh2 << std::endl; + CHECK(sh1.size() > 0); + CHECK(sh2.size() > 0); + + auto suffix_len = static_cast(std::min(sh1.size(), sh2.size())); + auto full_len = static_cast(std::max(sh1.size(), sh2.size())); + + std::cout << "Length" << suffix_len << full_len << (full_len - suffix_len - 1) << std::endl; + auto lower_bound = full_len - suffix_len - 1; + + for (int64_t i = full_len - 1; i > lower_bound; i--) { + std::cout << "Index i=" << i << std::endl; + auto dim1 = to_int(sh1[i]); + auto dim2 = to_int(sh2[i]); + if (dim1 != dim2) { + CHECK(false); + } + } + + Array larger; + Array smaller; + + for (int i = 0; i < (full_len - suffix_len); i++) { + smaller.push_back(tvm::ir::IntImm::make(1)); + } + + if (sh1.size() < sh2.size()) { + + } else if (sh1.size() > sh2.size()) { + + } else { + + } + + for (int i = 0; i < suffix_len - full_len; i++) { + + } + + return t1; +} + Array BroadcastRel(const Array & types, int num_args) { - std::cout << "Inside of Broadcast" << std::endl; - CHECK(types.size() == 0); + CHECK(types.size() == 3); if (auto t1 = as_ttype(types[0])) { if (auto t2 = as_ttype(types[1])) { - return types; + return { t1, t2, ConcreteBroadcast(t1, t2) }; } } return types; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 72fa9cb14bdc..cc91176feb62 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -66,22 +66,28 @@ struct TypeNormalizer : TypeFVisitor { Type VisitType_(const TypeCallNode * ty_call_node) { auto ty_call = GetRef(ty_call_node); - auto all_concrete = true; + Array normalized_args; + for (auto arg : ty_call->args) { + normalized_args.push_back(VisitType(arg)); + } + + auto all_concrete = true; + for (auto arg : normalized_args) { all_concrete = all_concrete && !arg.as(); } if (all_concrete) { - return ty_call->args[ty_call->args.size() - 1]; + return normalized_args[normalized_args.size() - 1]; } else { if (auto ty_rel_node = ty_call->func.as()) { // NB: we substract 1 for the output argument. auto new_args = ty_rel_node->func_(ty_call->args, ty_call->args.size() - 1); - CHECK(new_args.size() == ty_call->args.size()); + CHECK(new_args.size() == normalized_args.size()); tvm::Array final_args; for (int i = 0; i < new_args.size(); i++) { - final_args.push_back(unifier->unify(ty_call->args[i], new_args[i])); + final_args.push_back(unifier->unify(normalized_args[i], new_args[i])); } return TypeCallNode::make(ty_call->func, final_args); @@ -606,7 +612,7 @@ class TypeInferencer : private ExprFunctor { Type TypeInferencer::unify(const Type &t1, const Type &t2, Span sp) { try { - return this->unifier->unify(t1, t2); + return Normalize(this->unifier->unify(t1, t2)); } catch (const dmlc::Error &e) { std::stringstream ss; ss << "Error unifying `"; diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index e94158cd44e2..6e5b64ee846e 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -2,7 +2,7 @@ for expressions. """ from tvm.relay.type_infer import check_expr -from tvm.relay.ir_builder import IRBuilder, float_type, func_type +from tvm.relay.ir_builder import IRBuilder, float_type, func_type, tensor_type from tvm.relay.env import Environment from tvm.relay.op import log, add @@ -15,12 +15,11 @@ def has_type(expr, typ): def test_monomorphic_let(): "Program: let x = 1; return x" b = IRBuilder() - x = b.let('x', 1, value_type=float_type()) + x = b.let('x', 1.0, value_type=float_type(64)) b.ret(x) prog = b.get() - assert has_type(prog, float_type()) - + assert has_type(prog, float_type(64)) def test_single_op(): "Program: fn (x : float32) { let t1 = f(x); t1 }" @@ -32,9 +31,15 @@ def test_single_op(): assert has_type(func.to_func(), func_type([float_type()], float_type())) def test_dual_op(): - "Program: fn (x : float32) { let t1 = f(x); let t2 = g(t1, x); t1 }" + """Program: + fn (x : Tensor[f32, (10, 10)]) { + let t1 = log(x); + let t2 = add(t1, x); + return t1; + } + """ b = IRBuilder() - with b.function(('x', float_type())) as func: + with b.function(('x', tensor_type(10, 10))) as func: x, = func.param_ids() t1 = b.let('t1', log(x)) t2 = b.let('t2', add(t1, x)) @@ -43,5 +48,5 @@ def test_dual_op(): if __name__ == "__main__": # test_monomorphic_let() - test_single_op() + # test_single_op() test_dual_op() From 3e0cb42d16fcb43ce7f8a9e71b32112c5a44417d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 15:36:16 -0700 Subject: [PATCH 49/88] Address CR feedback --- include/tvm/relay/base.h | 9 +++- include/tvm/relay/environment.h | 24 ++++++----- include/tvm/relay/expr.h | 21 +++++----- include/tvm/relay/expr_functor.h | 38 +++-------------- include/tvm/relay/type.h | 4 +- src/relay/ir/type.cc | 2 +- src/relay/op/type_relations.cc | 2 +- src/relay/pass/resolve.cc | 6 +-- src/relay/pass/resolve.h | 6 +-- src/relay/pass/type_functor.h | 2 + src/relay/pass/type_infer.cc | 4 +- src/relay/pass/type_visitor.h | 70 ++++++++++++++++---------------- 12 files changed, 86 insertions(+), 102 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index f25d6e6532df..092f5ceb8fc3 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2018 by Contributors * \file tvm/relay/base.h - * \brief Base data structure for relay. + * \brief Base classes for the Relay IR. */ #ifndef TVM_RELAY_BASE_H_ #define TVM_RELAY_BASE_H_ @@ -13,7 +13,12 @@ namespace tvm { /*! - * \brief Relay: high level functional IR + * \brief Relay: a high level functional IR for TVM. + * + * This namespace contains the abstract syntax tree, and other + * essential data structures for the Relay IR. + * + * You can find more about Relay by reading the language reference. */ namespace relay { /*! diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index ff8e596059b5..ce874103a0a1 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -1,7 +1,8 @@ /*! * Copyright (c) 2018 by Contributors * \file tvm/relay/environment.h - * \brief The global environment, contains global state of Relay program. + * \brief The global environment: contains information needed to + * compile & optimize Relay programs. */ #ifndef TVM_RELAY_ENVIRONMENT_H_ #define TVM_RELAY_ENVIRONMENT_H_ @@ -21,18 +22,21 @@ struct Environment; /*! \brief The global environment of Relay programs. * - * The global environment contains all the global - * information needed to compile a Relay program, - * including the set of operators, the set of - * global functions, and configuration options. + * The global environment contains the global + * information needed to compile a Relay program. + * + * It contains all global functions, and configuration + * options. * * Many operations require acess to the global - * Environment. We mostly pass the argument by value - * in a functional style as an explicit argument. + * Environment. We pass the Environment by value + * in a functional style as an explicit argument, + * but we will mutate the Environment while optimizing + * Relay programs. * - * This means users can construct custom environments - * easily, for example a fresh environment for each - * thread while auto-tuning. + * The functional style allows users to construct custom + * environments easily, for example each thread can store + * an Environment while auto-tuning. * */ class EnvironmentNode : public RelayNode { diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index ff11a41a6e5f..5fe91702a29f 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2018 by Contributors * \file tvm/relay/expr.h - * \brief The Relay IR expression nodes. + * \brief Relay expression language. */ #ifndef TVM_RELAY_EXPR_H_ #define TVM_RELAY_EXPR_H_ @@ -16,11 +16,8 @@ namespace tvm { namespace relay { -// TOD0(@jroesch): best way to define? -class TypeInferencer; - /*! - * \brief Relay expression. + * \brief A Relay expression. */ class Expr; /*! @@ -28,7 +25,6 @@ class Expr; */ class ExprNode : public RelayNode { public: - // private: /*! * \brief Stores the result of type inference(type checking). * @@ -48,7 +44,6 @@ class ExprNode : public RelayNode { static constexpr const char* _type_key = "relay.Expr"; TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode); - friend class TypeInferencer; }; RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef); @@ -68,8 +63,6 @@ class ConstantNode : public ExprNode { /*! \brief The data of the tensor */ runtime::NDArray data; - // TODO(tqchen) add the function after we get TensorType constructor - // TODO(tqchen) create simple TensorType constructor for concrete types. /*! \return The corresponding tensor type of the data */ TensorType tensor_type() const; @@ -335,6 +328,12 @@ RELAY_DEFINE_NODE_REF(Let, LetNode, Expr); /*! * \brief Condition expression + * + * Unlike traditional statement `if`s, the if evalutes + * to the result of the branch taken. + * + * let x = if (true) { 1 } else { 0 }; // x is 1 + * let y = if (false) { 1 } else { 0 }; // y is 0 */ class If; /*! \brief container of If */ @@ -342,9 +341,9 @@ class IfNode : public ExprNode { public: /*! \brief The condition */ Expr cond; - /*! \brief The value to take when condition is true */ + /*! \brief The expression evaluated when condition is true. */ Expr true_value; - /*! \brief The value to take when condition is false */ + /*! \brief The expression evaluated when condition is false */ Expr false_value; IfNode() {} diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index e37a454eee41..4632733cbcfc 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -1,8 +1,8 @@ /*! * Copyright (c) 2018 by Contributors - * \file expr_functor.h - * \brief A more powerful Visitor that enables defining arbitrary function - * signatures with dispatch on first argument. + * \file tvm/relay/expr_functor.h + * \brief A more powerful visitor which enables defining arbitrary function + * signatures with type based dispatch on first argument. */ #ifndef TVM_RELAY_EXPR_FUNCTOR_H_ #define TVM_RELAY_EXPR_FUNCTOR_H_ @@ -19,36 +19,8 @@ namespace relay { * \brief A dynamical functor that dispatches on in the first Expr argument. * You can use this as a more powerful Visitor, since it allows you to * define function signatures of Visit Function. - * - * This helps you to avoid to book-keep return value of Visitor via state, - * which can cause bugs easily when state is incorrectly maintained. - * - * \code - * // A functor that set variable to b. and calculate results. - * class MyExprFunctor - * : public ir::ExprFunctor { - * public: - * int VisitExpr_(const Variable* op, int b) final { - * return b; - * } - * int VisitExpr_(const IntImm* op, int b) final { - * return op->value; - * } - * int VisitExpr_(const Add* op, int b) final { - * return Visit(op->a, b) + Visit(op->b, b); - * } - * }; - * MyExprFunctor f; - * Var x("x"); - * CHECK_EQ(f(x + 1, 2), 3); - * \endcode - * - * \note Why do we need this more powerful Functor: - * - * We often need to implement a transformer tasks. - * Say we want to take Expr and transform it to some analysis result, - * This easily be done incorrectly using plain Visitor. See IRVisitor's - * document for possible error cases. + * + * \sa tvm/ir_functor.h * * \tparam FType function signiture * This type if only defined for FType with function signiture R(const Expr&, diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index a6c801c382de..5d579b661280 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -298,10 +298,10 @@ class TupleTypeNode : public TypeNode { TVM_DLL static TupleType make(tvm::Array fields); static constexpr const char* _type_key = "relay.TypeTuple"; - TVM_DECLARE_NODE_TYPE_INFO(TypeTupleNode, TypeNode); + TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); +RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type); // The following fields contains advanced typing // Only keep the class name and reserved for future usage. diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 1faa9ede8638..e29f3cbde4c1 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -151,7 +151,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; }); -TypeCall TupleTypeNode::make(Array fields) { +TupleType TupleTypeNode::make(Array fields) { std::shared_ptr n = std::make_shared(); n->fields = std::move(fields); return TupleType(n); diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 883b8ecc946d..56b139731178 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -64,7 +64,7 @@ static Type ConcreteBroadcast(const TensorType & t1, const TensorType & t2) { Array smaller; for (int i = 0; i < (full_len - suffix_len); i++) { - smaller.push_back(tvm::ir::IntImm::make(1)); + // smaller.push_back(tvm::ir::IntImm::make(1)); } if (sh1.size() < sh2.size()) { diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index e86368854060..f18a67bcffc9 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -64,12 +64,12 @@ struct ResolveTypeExpr : ExprFVisitor<> { } }; -Type resolve(const TypeUnifier &unifier, const Type &ty) { +Type Resolve(const TypeUnifier &unifier, const Type &ty) { CHECK(ty.defined()); return ResolveTypeType(unifier).VisitType(ty); } -Expr resolve(const TypeUnifier &unifier, const Expr &expr) { +Expr Resolve(const TypeUnifier &unifier, const Expr &expr) { return ResolveTypeExpr(unifier).VisitExpr(expr); } @@ -91,7 +91,7 @@ struct FullyResolved : TypeVisitor<> { } }; -bool is_fully_resolved(const Type &t) { +bool IsFullyResolved(const Type &t) { auto fr = FullyResolved(); fr.VisitType(t); return fr.incomplete; diff --git a/src/relay/pass/resolve.h b/src/relay/pass/resolve.h index 5f6cc328a239..495c9658238a 100644 --- a/src/relay/pass/resolve.h +++ b/src/relay/pass/resolve.h @@ -13,9 +13,9 @@ namespace tvm { namespace relay { -Type resolve(const TypeUnifier & unifier, const Type & ty); -Expr resolve(const TypeUnifier & unifier, const Expr & expr); -bool is_fully_resolved(const Type & t); +Type Resolve(const TypeUnifier & unifier, const Type & ty); +Expr Resolve(const TypeUnifier & unifier, const Expr & expr); +bool IsFullyResolved(const Type & t); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/type_functor.h b/src/relay/pass/type_functor.h index 9adc1a08860e..9180703b49e8 100644 --- a/src/relay/pass/type_functor.h +++ b/src/relay/pass/type_functor.h @@ -65,6 +65,7 @@ class TypeFunctor { virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitTypeDefault_(const Node* op, Args...) { @@ -83,6 +84,7 @@ class TypeFunctor { RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); return vtable; } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index cc91176feb62..5e9e784dbe83 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -526,7 +526,7 @@ class TypeInferencer : private ExprFunctor { Type TypeInferencer::resolve(const Type &t) { if (t.defined()) { - return ::tvm::relay::resolve(this->unifier, t); + return ::tvm::relay::Resolve(this->unifier, t); } else { return IncompleteTypeNode::make(TypeParamNode::Kind::kType); } @@ -534,7 +534,7 @@ class TypeInferencer : private ExprFunctor { Expr TypeInferencer::resolve(const Expr &e) { CHECK(e.defined()); - return ::tvm::relay::resolve(this->unifier, e); + return ::tvm::relay::Resolve(this->unifier, e); } void TypeInferencer::CheckOp(Op op) { diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index 68dba76644c3..f3c0f9a74fb7 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -22,7 +22,7 @@ struct TypeVisitor : ::tvm::relay::TypeFunctor { void VisitType_(const TypeParamNode* op, Args... args) override {} void VisitType_(const FuncTypeNode* op, Args... args) override { - // fix me handle poly + // TODO(@jroesch): handle poly // this->VisitType(op->var, args...); // this->VisitType(op->boundType, args...); for (auto arg_type : op->arg_types) { @@ -33,11 +33,11 @@ struct TypeVisitor : ::tvm::relay::TypeFunctor { void VisitType_(const TensorTypeNode* op, Args... args) override {} - // void VisitType_(const TupleTypeNode* op, Args... args) override { - // for (const Type& t : op->fields) { - // this->VisitType(t, args...); - // } - // } + void VisitType_(const TupleTypeNode* op, Args... args) override { + for (const Type& t : op->fields) { + this->VisitType(t, args...); + } + } void VisitType_(const TypeCallNode* op, Args... args) override { this->VisitType(op->func, args...); @@ -63,46 +63,48 @@ struct TypeFVisitor : TypeFunctor { } Type VisitType_(const FuncTypeNode* op) override { + // TODO (@jroesch): handle poly + // auto new_id = this->VisitType(op->var); // if (const TypeParamNode* tin = new_id.as()) { // return TypeQuantifierNode::make(GetRef(tin), // this->VisitType(op->boundType)); - std::vector args; - for (auto arg_type : op->arg_types) { - args.push_back(VisitType(arg_type)); - } - - return FuncTypeNode::make(tvm::Array(args), - VisitType(op->ret_type), {}, {}); // fix me + std::vector args; + for (auto arg_type : op->arg_types) { + args.push_back(VisitType(arg_type)); } - // Type VisitType_(const TupleTypeNode* op) override { - // std::vector new_fields; - // for (const Type& t : op->fields) { - // new_fields.push_back(this->VisitType(t)); - // } - // return TupleTypeNode::make(new_fields); - // } - - Type VisitType_(const TypeRelationNode* op) override { - return GetRef(op); - } + return FuncTypeNode::make(tvm::Array(args), VisitType(op->ret_type), + {}, {}); // fix me + } - Type VisitType_(const TypeCallNode* op) override { - auto func = this->VisitType(op->func); - std::vector new_args; - for (const Type& t : op->args) { - new_args.push_back(this->VisitType(t)); + Type VisitType_(const TupleTypeNode* op) override { + std::vector new_fields; + for (const Type& t : op->fields) { + new_fields.push_back(this->VisitType(t)); } - return TypeCallNode::make(func, new_args); + return TupleTypeNode::make(new_fields); } - Type VisitType_(const IncompleteTypeNode* op) override { - return GetRef(op); + Type VisitType_(const TypeRelationNode* op) override { + return GetRef(op); + } + + Type VisitType_(const TypeCallNode* op) override { + auto func = this->VisitType(op->func); + std::vector new_args; + for (const Type& t : op->args) { + new_args.push_back(this->VisitType(t)); } - }; + return TypeCallNode::make(func, new_args); + } + + Type VisitType_(const IncompleteTypeNode* op) override { + return GetRef(op); + } +}; } // namespace relay -} // namespace relay +} // namespace tvm #endif // TVM_RELAY_TYPE_VISITOR_H_ From c0c8d574832bf56c76d21f8385ecdb3fff752962 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 15:46:45 -0700 Subject: [PATCH 50/88] Address more CR feedback --- include/tvm/relay/expr_visitor.h | 9 ++++++--- include/tvm/relay/logging.h | 2 +- python/tvm/relay/__init__.py | 1 - 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 8803aa5ae48f..e15f25a39eb3 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -1,8 +1,11 @@ /*! * Copyright (c) 2018 by Contributors - * \file expr_visitor.h - * \brief A simple visitor wrapper around ExprFunctor designed for visitors which - * maintain mutable state. + * \file tvm/relay/expr_visitor.h + * \brief A simple visitor wrapper around ExprFunctor. + * + * Exposes two visitors with default traversal strategies, one + * which doesn't compute a result but can mutate internal state, + * and another which functionally builds a new Expr. */ #ifndef TVM_RELAY_EXPR_VISITOR_H_ #define TVM_RELAY_EXPR_VISITOR_H_ diff --git a/include/tvm/relay/logging.h b/include/tvm/relay/logging.h index 99cfc44de6cb..c53cd15ee72e 100644 --- a/include/tvm/relay/logging.h +++ b/include/tvm/relay/logging.h @@ -8,10 +8,10 @@ #ifndef TVM_RELAY_LOGGING_H_ #define TVM_RELAY_LOGGING_H_ +#include #include #include #include -#include "dmlc/logging.h" namespace tvm { namespace relay { diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 019d7c19a865..c36b9bcf8357 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -18,7 +18,6 @@ # Expr Constant = expr.Constant Tuple = expr.Tuple -# TODO: GlobalVar, LocalVar-> var LocalVar = expr.LocalVar GlobalVar = expr.GlobalVar Param = expr.Param From 7610941c9585bbcc12573e3081e88825f1537a2d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:24:47 -0700 Subject: [PATCH 51/88] Add SourceMap and clean up environment.h --- include/tvm/relay/environment.h | 12 ++--- include/tvm/relay/source_map.h | 44 +++++++++++++++++ src/relay/source_map.cc | 88 +++++++++++++++++++++++++++++++++ 3 files changed, 138 insertions(+), 6 deletions(-) create mode 100644 include/tvm/relay/source_map.h create mode 100644 src/relay/source_map.cc diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index ce874103a0a1..43be0ab8c912 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -13,7 +13,7 @@ #include "./type.h" #include "./op.h" #include "./error.h" -// #include "tvm/relay/source_map.h" +#include "tvm/relay/source_map.h" namespace tvm { namespace relay { @@ -43,10 +43,10 @@ class EnvironmentNode : public RelayNode { private: /*! \brief A map from string names to global variables ensures global uniqueness. */ tvm::Map global_map_; - // /*! \brief A map from file names to source fragments. */ - // SourceMap source_map_ - // /*! \brief A list of the errors reported during the current run. */ - // std::vector errors_; + /*! \brief A map from file names to source fragments. */ + SourceMap source_map_; + /*! \brief A list of the errors reported during the current run. */ + std::vector errors_; public: /*! \brief A map from ids to all global functions. */ @@ -73,7 +73,7 @@ class EnvironmentNode : public RelayNode { Function Lookup(const std::string & s); /*! \brief Add a source fragment to the environment. */ - // FileId add_source(std::string file_name, std::string source); + SourceName AddSource(std::string file_name, std::string source); void ReportError(std::string msg, Span sp); void DisplayErrors(); diff --git a/include/tvm/relay/source_map.h b/include/tvm/relay/source_map.h new file mode 100644 index 000000000000..71bf93aa1ed9 --- /dev/null +++ b/include/tvm/relay/source_map.h @@ -0,0 +1,44 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file source_map.h + * \brief A representation of source files and a data structure for + * storing them. + */ +#ifndef TVM_RELAY_SOURCE_MAP_H_ +#define TVM_RELAY_SOURCE_MAP_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { + +struct SourceFragment { + std::string file_name; + std::vector source_lines; + + SourceFragment(std::string file_name, std::string source); + + SourceFragment(const SourceFragment& sf) { + this->file_name = sf.file_name; + this->source_lines = sf.source_lines; + } + + std::string SourceAt(Span sp, int lines); +}; + +/*! \brief Maps from FileId's to a SourceFragment. + */ +class SourceMap { + /*! \brief Map from unique token to a fragment of a source file. */ + std::unordered_map map_; + public: + SourceMap() : map_() {} + SourceName AddSource(std::string file_name, std::string source); + const SourceFragment & GetSource(SourceName id) const; +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_SOURCE_MAP_H_ \ No newline at end of file diff --git a/src/relay/source_map.cc b/src/relay/source_map.cc new file mode 100644 index 000000000000..0db80fd30339 --- /dev/null +++ b/src/relay/source_map.cc @@ -0,0 +1,88 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file source_map.cc + * \brief Source maps for Relay. + */ + +#include +#include +#include + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +SourceFragment::SourceFragment(std::string file_name, std::string source) + : file_name(file_name), source_lines({}) { + RELAY_LOG(INFO)<< "SourceFragment::SourceFragment source=" << source << std::endl; + std::stringstream source_stream; + source_stream.str(source.c_str()); + std::string line; + + while (std::getline(source_stream, line)) { + RELAY_LOG(INFO) << "SourceFragment::SourceFragment: line=" << line << std::endl; + std::string copy(line); + source_lines.push_back(copy); + } +} + +std::string SourceFragment::SourceAt(Span sp, int max_lines) { + std::stringstream out; + + // We need to move from 1 based indexing to zero based indexing. + int starting_line = sp->lineno; + + if (starting_line >= static_cast(this->source_lines.size())) { + throw dmlc::Error("SourceFragment: index out of bounds"); + } + + auto lines = std::max(static_cast(max_lines), source_lines.size() - starting_line); + + for (size_t i = 0; i < lines; i++) { + out << std::endl << this->source_lines.at(starting_line + i); + } + + auto source_slice = out.str(); + + RELAY_LOG(INFO) << "SourceFragment::SourceAt: source_slice=" << source_slice << std::endl; + return source_slice; +} + +SourceName SourceMap::AddSource(std::string file_name, std::string source) { + auto new_id = SourceNameNode::make(file_name); + SourceFragment sfile(file_name, source); + this->map_.insert({new_id, sfile}); + return new_id; +} + +SourceName SourceNameNode::make(std::string name) { + std::shared_ptr n = std::make_shared(); + n->name = std::move(name); + return SourceName(n); +} + +static SourceFragment DUMMY_SOURCE = SourceFragment("DUMMY_FILE", "DUMMY_SOURCE"); + +SourceFragment const &SourceMap::GetSource(SourceName id) const { + auto item = map_.find(id); + if (item != map_.end()) { + return (*item).second; + } else { + return DUMMY_SOURCE; + } +} + +TVM_REGISTER_API("relay._make.SourceName") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { + *ret = SourceNameNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { + p->stream << "SourceNameNode(" << node->name << ", " << node << ")"; + }); + +} // namespace relay +} // namespace tvm \ No newline at end of file From 0531b737b219ba2d5f4af8b70122049665c3c4d9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:35:43 -0700 Subject: [PATCH 52/88] Reogranize a bit --- include/tvm/relay/expr.h | 4 ++-- src/relay/ir/base.cc | 18 ++++++++++++++++++ src/relay/source_map.cc | 22 ++-------------------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 5fe91702a29f..ddac633f9d09 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include "./base.h" #include "./type.h" @@ -223,8 +224,7 @@ class FunctionNode : public ExprNode { RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); -// TODO(tqchen) change Expr to Attr after we introduce Attr system. -using Attrs = tvm::Map; +using Attrs = tvm::Attrs; /*! * \brief Call corresponds to operator invocation. diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 5fdf96ded224..d48b9a4c3e0c 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -12,6 +12,24 @@ namespace relay { using tvm::IRPrinter; using namespace tvm::runtime; +SourceName SourceNameNode::make(std::string name) { + std::shared_ptr n = std::make_shared(); + n->name = std::move(name); + return SourceName(n); +} + +// TVM_REGISTER_API("relay._make.SourceName") +// .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { +// *ret = SourceNameNode::make(args[0]); +// }); + +// This causes a crash? + +// TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +// .set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { +// p->stream << "SourceNameNode(" << node->name << ", " << node << ")"; +// }); + Span SpanNode::make(SourceName source, int lineno, int col_offset) { std::shared_ptr n = std::make_shared(); n->source = std::move(source); diff --git a/src/relay/source_map.cc b/src/relay/source_map.cc index 0db80fd30339..a1b3627bccc8 100644 --- a/src/relay/source_map.cc +++ b/src/relay/source_map.cc @@ -57,32 +57,14 @@ SourceName SourceMap::AddSource(std::string file_name, std::string source) { return new_id; } -SourceName SourceNameNode::make(std::string name) { - std::shared_ptr n = std::make_shared(); - n->name = std::move(name); - return SourceName(n); -} - -static SourceFragment DUMMY_SOURCE = SourceFragment("DUMMY_FILE", "DUMMY_SOURCE"); - -SourceFragment const &SourceMap::GetSource(SourceName id) const { +const SourceFragment& SourceMap::GetSource(SourceName id) const { auto item = map_.find(id); if (item != map_.end()) { return (*item).second; } else { - return DUMMY_SOURCE; + throw dmlc::Error("could not find requested source fragment"); } } -TVM_REGISTER_API("relay._make.SourceName") - .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { - *ret = SourceNameNode::make(args[0]); - }); - -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { - p->stream << "SourceNameNode(" << node->name << ", " << node << ")"; - }); - } // namespace relay } // namespace tvm \ No newline at end of file From aacf30be58e1ea6b83c22a362bca3b1842add709 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:45:02 -0700 Subject: [PATCH 53/88] Kill dead code in env.py --- python/tvm/relay/_env.pyi | 15 +------------- python/tvm/relay/env.py | 43 --------------------------------------- 2 files changed, 1 insertion(+), 57 deletions(-) diff --git a/python/tvm/relay/_env.pyi b/python/tvm/relay/_env.pyi index d14e726e5443..c6b5d0f6c4bd 100644 --- a/python/tvm/relay/_env.pyi +++ b/python/tvm/relay/_env.pyi @@ -2,17 +2,4 @@ from typing import Union, Tuple, Dict, List from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId from relay.ir import ShapeExtension, Operator, Defn -class Environment(NodeBase): ... - -def Environment_add(self: Environment, func: GlobalId) -> None: ... -def Environment_global_id(self: Environment, name: str) -> GlobalId: ... -def Environment_operator_id(self: Environment, name: str) -> OperatorId: ... -def Environment_lookup_global(self: Environment, id: GlobalId) -> Item: ... -def Environment_lookup_operator(self: Environment, id: OperatorId) -> Item: ... -def Environment_remove_global(self: Environment, id: GlobalId) -> Item: ... -def Environment_add_source(self: Environment, file_name: str, source: str) -> FileId: ... -def Environment_report_error(self: Environment, message: str, span: Span) -> None: ... -def Environment_display_errors(self: Environment) -> None: ... -def Environment_register_shape_ext(self: Environment, shape_ext: ShapeExtension) -> None: ... -def Environment_get_operators(self: Environment) -> List[Operator]: ... -def Environment_get_defns(self: Environment) -> List[Defn]: ... +class Environment(NodeBase): ... \ No newline at end of file diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index c63197fa8509..4de5a0c02772 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -3,15 +3,8 @@ from typing import Union, List from .base import register_relay_node, NodeBase from . import _make -# from relay.ir import GlobalId, OperatorId, Item, FileId, Span, ShapeExtension -# from relay.ir import Operator, Defn -# from relay._env import * import tvm -# Move me to C++ if possible. -__tgt_host__ = __tgt__ = "llvm" -__relay_tvm_context__ = tvm.cpu() - @register_relay_node class Environment(NodeBase): """The global Relay environment containing definitions, @@ -19,39 +12,3 @@ class Environment(NodeBase): """ def __init__(self, funcs) -> None: self.__init_handle_by_constructor__(_make.Environment, funcs) - - # def add(self, item: Item) -> None: - # return Environment_add(self, item) - - # def global_id(self, name: str) -> GlobalId: - # return Environment_global_id(self, name) - - # def operator_id(self, name: str) -> OperatorId: - # return Environment_operator_id(self, name) - - # def lookup(self, ident: Union[GlobalId, OperatorId]) -> Item: - # if isinstance(ident, OperatorId): - # return Environment_lookup_operator(self, ident) - # else: - # return Environment_lookup_global(self, ident) - - # def add_source(self, file_name: str, source: str) -> FileId: - # return Environment_add_source(self, file_name, source) - - # def report_error(self, message: str, span: Span) -> None: - # return Environment_report_error(self, message, span) - - # def register_shape_ext(self, ext: ShapeExtension) -> None: - # return Environment_register_shape_ext(self, ext) - - # def display_errors(self) -> None: - # return Environment_display_errors(self) - - # def operators(self) -> List[Operator]: - # return Environment_get_operators(self) - - # def defns(self) -> List[Defn]: - # return Environment_get_defns(self) - - # def tvm_context(self): - # return __relay_tvm_context__ From 6cad8668325d57590cb720c93e4eb71b3bbc4ec8 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:45:31 -0700 Subject: [PATCH 54/88] Fix commit mistake --- cmake/config.cmake | 3 --- 1 file changed, 3 deletions(-) diff --git a/cmake/config.cmake b/cmake/config.cmake index e09fdb241bf1..c364a88cce11 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -19,9 +19,6 @@ # $ make -j8 #-------------------------------------------------------------------- -SET(CMAKE_C_COMPLIER clang) -SET(CMAKE_CXX_COMPILER clang++) - #--------------------------------------------- # Backend runtimes. #--------------------------------------------- From 61f3534d576a14897082c01007805cc3ec06165b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:48:05 -0700 Subject: [PATCH 55/88] Move type_infer into pass.py --- python/tvm/relay/{_type_infer.py => _pass.py} | 0 python/tvm/relay/{_type_infer.pyi => _pass.pyi} | 0 python/tvm/relay/{type_infer.py => pass.py} | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) rename python/tvm/relay/{_type_infer.py => _pass.py} (100%) rename python/tvm/relay/{_type_infer.pyi => _pass.pyi} (100%) rename python/tvm/relay/{type_infer.py => pass.py} (78%) diff --git a/python/tvm/relay/_type_infer.py b/python/tvm/relay/_pass.py similarity index 100% rename from python/tvm/relay/_type_infer.py rename to python/tvm/relay/_pass.py diff --git a/python/tvm/relay/_type_infer.pyi b/python/tvm/relay/_pass.pyi similarity index 100% rename from python/tvm/relay/_type_infer.pyi rename to python/tvm/relay/_pass.pyi diff --git a/python/tvm/relay/type_infer.py b/python/tvm/relay/pass.py similarity index 78% rename from python/tvm/relay/type_infer.py rename to python/tvm/relay/pass.py index 17938dfdcbc4..9d7902686928 100644 --- a/python/tvm/relay/type_infer.py +++ b/python/tvm/relay/pass.py @@ -1,6 +1,6 @@ #pylint: disable-all -from . import _type_infer +from . import _pass check_expr = _type_infer.check_expr # generalize = _type_infer.generalize From 37753b6cad8f23235d556416b66c4b3f896c8d4c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:54:18 -0700 Subject: [PATCH 56/88] Reorganize passes a bit --- include/tvm/relay/pass.h | 23 +++++++++++ include/tvm/relay/pass/type_infer.h | 10 +---- python/tvm/relay/_pass.py | 2 +- src/relay/pass/type_infer.cc | 59 ++--------------------------- 4 files changed, 29 insertions(+), 65 deletions(-) create mode 100644 include/tvm/relay/pass.h diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h new file mode 100644 index 000000000000..89f3dd48fc31 --- /dev/null +++ b/include/tvm/relay/pass.h @@ -0,0 +1,23 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/pass.h + * \brief The set of Relay passes written in C++. + */ +#ifndef TVM_RELAY_PASS_H_ +#define TVM_RELAY_PASS_H_ + +#include "tvm/relay/expr.h" +#include "tvm/relay/environment.h" + +namespace tvm { +namespace relay { + +/*! The result of type checking an expression is a new expression + * with unambigous type information filled in, as well as it's + * checked type field populated with the result type. + */ +Expr InferType(const Environment & env, const Expr & e); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_TYPECHECKER_H_ \ No newline at end of file diff --git a/include/tvm/relay/pass/type_infer.h b/include/tvm/relay/pass/type_infer.h index 9a8ab2bc6a8b..2b860a5e89ef 100644 --- a/include/tvm/relay/pass/type_infer.h +++ b/include/tvm/relay/pass/type_infer.h @@ -6,8 +6,8 @@ * The pass produces a new expression with its checked_type * field populated and incomplete types resolved. */ -#ifndef TVM_RELAY_PASS__TYPECHECKER_H_ -#define TVM_RELAY_PASS__TYPECHECKER_H_ +#ifndef TVM_RELAY_PASS_TYPECHECKER_H_ +#define TVM_RELAY_PASS_TYPECHECKER_H_ #include "tvm/relay/expr.h" #include "tvm/relay/environment.h" @@ -15,12 +15,6 @@ namespace tvm { namespace relay { -/*! The result of type checking an expression is a new expression - * with unambigous type information filled in, as well as it's - * checked type field populated with the result type. - */ -Expr InferType(const Environment & env, const Expr & e); - /*! \brief Ensures that an operator is well-formed with respect * to Relay's type system. */ diff --git a/python/tvm/relay/_pass.py b/python/tvm/relay/_pass.py index 7213769a4164..052ba6d4a0fb 100644 --- a/python/tvm/relay/_pass.py +++ b/python/tvm/relay/_pass.py @@ -2,4 +2,4 @@ from tvm._ffi.function import _init_api -_init_api("relay._type_infer", __name__) +_init_api("relay._pass", __name__) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 5e9e784dbe83..746323fc6d56 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -626,73 +626,20 @@ class TypeInferencer : private ExprFunctor { } } - // // template - - // // Add safe dynamic Array downcast. - // // Add static upcast? - - // // Add to type utils. - // Array type_parameters(const Type &t) { - // Array params; - // auto type = t; - // const TypeQuantifierNode *ty_quant; - // while ((ty_quant = type.as())) { - // params.push_back(ty_quant->id); - // type = ty_quant->boundType; - // } - - // return params; - // } - - // template - // Array ArrayMap(const Array &data, F f) { - // // probably a way to use std::transform. - // Array output; - // for (const I &el : data) { - // output.push_back(f(el)); - // } - // return output; - // } - - // // There are some important questions around generalization - // // that we need to answer. - // Expr generalize(const Environment &env, const Expr &e) { - // if (auto fn_node = e.as()) { - // TypeInferencer tc(env); - // auto ty = tc.VisitFunction(GetRef(fn_node), true); - // auto ty_params = type_parameters(ty); - // auto params = ArrayMap(fn_node->params, [&](const Param &p) { - // return ParamNode::make(p->id, tc.resolve(p->type)); - // }); - // auto body = tc.resolve(fn_node->body); - // auto ret_type = tc.resolve(fn_node->ret_type); - // auto fn = FunctionNode::make(ty_params, params, ret_type, body); - // // we should check in empty context to ensure typing is preserved. - // // check(env, fn); - // return fn; - // } else { - // throw dmlc::Error("can only apply generalize to a function."); - // } - // } - - TVM_REGISTER_API("relay._type_infer.check_expr") + TVM_REGISTER_API("relay._pass.check_expr") .set_body([](TVMArgs args, TVMRetValue *ret) { Environment env = args[0]; Expr e = args[1]; *ret = Infer(env, e); }); - TVM_REGISTER_API("relay._type_infer._get_checked_type") + // TODO(@jroesch): put in a better namespace. + TVM_REGISTER_API("relay._pass._get_checked_type") .set_body([](TVMArgs args, TVMRetValue *ret) { Expr e = args[0]; *ret = e->checked_type(); }); - // TVM_REGISTER_API("relay._tyck.generalize") - // .set_body([](TVMArgs args, TVMRetValue *ret) { - // *ret = generalize(args[0], args[1]); - // }); - IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { std::shared_ptr n = std::make_shared(); From 5cd2bbe6c9da42f1ea06fff8d198007f3e53b550 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:57:12 -0700 Subject: [PATCH 57/88] More cleaning --- python/tvm/relay/type.py | 37 ------------------------------------- src/relay/ir/type.cc | 26 +++----------------------- 2 files changed, 3 insertions(+), 60 deletions(-) diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index c9a96de4889d..70e4666e96f9 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -4,7 +4,6 @@ from enum import IntEnum from .base import Span, NodeBase, register_relay_node from tvm import expr -# TODO(@jroesch): move me from . import _make class Type(NodeBase): @@ -85,39 +84,3 @@ class IncompleteType(Type): def __init__(self, kind: Kind) -> None: self.__init_handle_by_constructor__(_make.IncompleteType, kind) - -def IntType(bits: int, lanes: int=1) -> Type: - """Constructs a integer base type. - - :param bits: The bit width of the integer type. - :param lanes: The number of vector elements for this datatype. - - """ - return _make.IntType(bits, lanes) - - -def UIntType(bits: int, lanes: int=1) -> Type: - """Constructs a unsigned integer base type. - - :param bits: The bit width of the unsigned type. - :param lanes: The number of vector elements for this datatype. - """ - return _make.UIntType(bits, lanes) - - -def FloatType(bits: int, lanes: int=1) -> Type: - """Constructs a floating point base type. - - :param bits: The bit width of the unsigned type. - :param lanes: The number of vector elements for this datatype. - """ - return _make.FloatType(bits, lanes) - - -def BoolType(lanes: int =1) -> Type: - """Constructs a boolean base type. - - :param bits: The bit width of the unsigned type. - :param lanes: The number of vector elements for this datatype. - """ - return _make.BoolType(lanes) diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index e29f3cbde4c1..2975c60cc0c1 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -41,26 +41,6 @@ TVM_REGISTER_API("relay._make.TensorType") *ret = TensorTypeNode::make(shape, args[1]); }); -TVM_REGISTER_API("relay._make.IntType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TensorTypeNode::Int(args[0], args[1]); - }); - -TVM_REGISTER_API("relay._make.UIntType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TensorTypeNode::UInt(args[0], args[1]); - }); - -TVM_REGISTER_API("relay._make.BoolType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TensorTypeNode::Bool(args[0]); - }); - -TVM_REGISTER_API("relay._make.FloatType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TensorTypeNode::Float(args[0], args[1]); - }); - TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TensorTypeNode *node, tvm::IRPrinter *p) { @@ -113,7 +93,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->type_constraints << ")"; }); -TypeRelation TypeRelationNode::make(std::string name, int num_args, TypeRelationFn func) { +TypeRelation TypeRelationNode::make(std::string name, int num_args, + TypeRelationFn func) { std::shared_ptr n = std::make_shared(); n->name = std::move(name); n->num_args = std::move(num_args); @@ -164,10 +145,9 @@ TVM_REGISTER_API("relay._make.TupleType") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TupleTypeNode *node, - tvm::IRPrinter *p) { + tvm::IRPrinter *p) { p->stream << "TupleTypeNode(" << node->fields << ")"; }); - } // namespace relay } // namespace tvm From 828a62b915cb1a069d5cc8231b844a4ea3f8c299 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 17:00:55 -0700 Subject: [PATCH 58/88] Remove dead code --- python/tvm/relay/pass.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/pass.py b/python/tvm/relay/pass.py index 9d7902686928..8c352e58843d 100644 --- a/python/tvm/relay/pass.py +++ b/python/tvm/relay/pass.py @@ -3,4 +3,3 @@ from . import _pass check_expr = _type_infer.check_expr -# generalize = _type_infer.generalize From c6bbb860e3d83c7ce5ba904605f7a9a7bf2ff2e7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 17:06:56 -0700 Subject: [PATCH 59/88] Clean up code in Unifier --- src/relay/pass/unifier.cc | 69 ++++++++++++++++----------------------- src/relay/pass/unifier.h | 2 +- 2 files changed, 29 insertions(+), 42 deletions(-) diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index c6a4e7dfba6d..4d986ad79ab1 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -11,7 +11,6 @@ #include #include "./unifier.h" #include "./type_visitor.h" -//#include "./type_subst.h" // #include "tvm/relay/typeck/kindchecker.h" namespace tvm { @@ -298,52 +297,40 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { throw UnificationError("Cannot unify TensorTypeNode"); } -// Type TypeUnifierNode::VisitType_(const TupleTypeNode *t1, const Type rt2) { -// TupleType pt1 = GetRef(t1); +Type TypeUnifierNode::VisitType_(const TupleTypeNode *t1, const Type rt2) { + TupleType pt1 = GetRef(t1); -// // for typevar, remap and attempt to unify if already defined -// if (const IncompleteTypeNode *tvn2 = rt2.as()) { -// return this->unifyWithIncompleteType(pt1, GetRef(tvn2)); -// } + // When unifying tuple types we just solve each field in order. + if (const TupleTypeNode *ptn2 = rt2.as()) { + TupleType pt2 = GetRef(ptn2); -// // for other product types, unify item by item -// if (const TupleTypeNode *ptn2 = rt2.as()) { -// TupleType pt2 = GetRef(ptn2); - -// std::vector unified_fields; -// if (pt1->fields.size() != pt2->fields.size()) { -// throw UnificationError("Product types are of different dimensions"); -// } - -// for (size_t i = 0U; i < pt1->fields.size(); i++) { -// Type unified = this->VisitType(pt1->fields[i], pt2->fields[i]); -// unified_fields.push_back(unified); -// } - -// return TupleTypeNode::make(unified_fields); -// } - -// // otherwise cannot unify -// throw UnificationError("Cannot unify TupleTypeNode"); -// } + std::vector unified_fields; + if (pt1->fields.size() != pt2->fields.size()) { + throw UnificationError("Product types are of different dimensions"); + } -Type TypeUnifierNode::VisitType_(const TypeRelationNode *sen1, const Type t2) { -// ShapeExtension sh_ext1 = GetRef(sen1); + for (size_t i = 0U; i < pt1->fields.size(); i++) { + Type unified = this->VisitType(pt1->fields[i], pt2->fields[i]); + unified_fields.push_back(unified); + } -// if (const IncompleteTypeNode *tvn2 = t2.as()) { -// return this->unifyWithIncompleteType(sh_ext1, GetRef(tvn2)); -// } + return TupleTypeNode::make(unified_fields); + } -// // will only attempt to unify with binary op with same op -// if (const ShapeExtensionNode *sen2 = t2.as()) { -// if (sh_ext1->name != sen2->name) { -// throw UnificationError( -// "Cannot unify shape projections of different index"); -// } -// } + // otherwise cannot unify + throw UnificationError("Cannot unify TupleTypeNode"); +} -// return sh_ext1; - return t2; +Type TypeUnifierNode::VisitType_(const TypeRelationNode *tr1, const Type t2) { + if (const TypeRelationNode *tr2 = t2.as()) { + if (tr1 == tr2) { + return GetRef(tr1); + } else { + throw UnificationError("Cannot unify different type relations"); + } + } else { + throw UnificationError("Cannot unify type relation with another type of type"); + } } Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { diff --git a/src/relay/pass/unifier.h b/src/relay/pass/unifier.h index aecc428cb6a9..5a4adea5c44e 100644 --- a/src/relay/pass/unifier.h +++ b/src/relay/pass/unifier.h @@ -109,7 +109,7 @@ class TypeUnifierNode : public Node, Type VisitType_(const TensorTypeNode* t1, const Type t2) override; Type VisitType_(const TypeParamNode* t1, const Type t2) override; Type VisitType_(const FuncTypeNode* t1, const Type t2) override; - // Type VisitType_(const TupleTypeNode* t1, const Type t2) override; + Type VisitType_(const TupleTypeNode* t1, const Type t2) override; Type VisitType_(const TypeRelationNode* s1, const Type t2) override; Type VisitType_(const TypeCallNode* s1, const Type t2) override; }; From 408b51a012264ab9c04e37d5281fd74102da4ba5 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 17:21:37 -0700 Subject: [PATCH 60/88] Clean up environment.h --- include/tvm/relay/environment.h | 3 +- src/relay/ir/environment.cc | 203 ++++++++++---------------------- 2 files changed, 60 insertions(+), 146 deletions(-) diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index 43be0ab8c912..aa41882db46e 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -50,7 +50,7 @@ class EnvironmentNode : public RelayNode { public: /*! \brief A map from ids to all global functions. */ - tvm::Map items; + tvm::Map functions; EnvironmentNode() {} @@ -60,7 +60,6 @@ class EnvironmentNode : public RelayNode { tvm::Map global_funcs); void Add(const GlobalVar& var, const Function & func, bool update = false); - void TryAdd(const GlobalVar& var, const Function & func, bool update=false); void Update(const GlobalVar& var, const Function & func); void Remove(const GlobalVar& var); diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index 8c155e3bc1bd..63a42b9d0e3e 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -5,11 +5,7 @@ */ #include #include "tvm/relay/environment.h" -// #include "tvm/relay/alpha_eq.h" -// #include "tvm/relay/debug.h" -// #include "tvm/relay/typeck/typechecker.h" // #include "tvm/relay/util/rang.h" -// #include "tvm/runtime/packed_func_ext.h" namespace tvm { namespace relay { @@ -20,7 +16,7 @@ using namespace tvm::runtime; Environment EnvironmentNode::make( tvm::Map global_funcs) { std::shared_ptr n = std::make_shared(); - n->items = std::move(global_funcs); + n->functions = std::move(global_funcs); return Environment(n); } @@ -35,10 +31,13 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { } } -// // Add a new item to the global environment -// // throws an exception if the item already -// // exists. -// void EnvironmentNode::add(const Item &unchecked_item, bool update) { +/*! \brief Add a new item to the global environment + * \note if the update flag is not set adding a duplicate + * definition will trigger an exception, otherwise we will + * update the definition if and only if it is type compatible. + */ +void EnvironmentNode::Add(const GlobalVar& var, const Function & func, bool update) { + throw Error("NYI"); // // Type check the item before we add it to the environment. // auto env = GetRef(this); // Item item = check(env, unchecked_item); @@ -85,14 +84,22 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { // throw EnvError("internal error: unknown item type, unreachable code"); // } // } +} -// void EnvironmentNode::update(const Item &item) { return this->add(item, true); } +void EnvironmentNode::Update(const GlobalVar& var, const Function & func) { + this->Add(var, func, true); +} -// void EnvironmentNode::remove(const GlobalId &id) { this->items.erase(id); } +void EnvironmentNode::Remove(const GlobalVar&) { + // Clarify with @tqchen about how to use COW to do this. + throw Error("NYI"); + // this->items.erase(id); +} Function EnvironmentNode::Lookup(const GlobalVar &var) { - if (items.find(var) != items.end()) { - return items.at(var); + auto func = functions.find(var); + if (func != functions.end()) { + return (*func).second; } else { throw Error(std::string("there is no definition of ") + var->name_hint); } @@ -103,143 +110,51 @@ Function EnvironmentNode::Lookup(const std::string &str) { return this->Lookup(id); } -// inline FileId EnvironmentNode::add_source(std::string file_name, -// std::string source) { -// return this->source_map_.add_source(file_name, source); -// } - -// void EnvironmentNode::report_error(std::string msg, Span sp) { -// this->errors_.push_back(Error(msg, sp)); -// } - -// void EnvironmentNode::display_errors() { -// for (auto err : this->errors_) { -// auto sp = err.sp; -// auto source_file = this->source_map_.GetSource(err.sp->file_id); -// auto file_name = source_file.file_name; -// auto source_at_span = source_file.SourceAt(err.sp, 1); -// std::string error_marker = "error:"; -// auto line_info = -// std::to_string(sp->lineno) + ":" + std::to_string(sp->col_offset); - -// std::cout << rang::style::bold << rang::fg::red << error_marker -// << rang::fg::reset << file_name << ":" << line_info -// << rang::style::reset << " " << source_at_span << std::endl; - -// // Build the cursor. - -// // Fix this code, hardwired to compute alignment of pointer. -// size_t spaces = error_marker.size() + line_info.size() + file_name.size() + -// sp->col_offset - 3; - -// std::string cursor = "~~~~^~~~~"; -// for (size_t i = 0; i < spaces; i++) { -// std::cout << " "; -// } -// std::cout << rang::fg::red << cursor << " " << err.msg << rang::style::reset -// << std::endl; -// } -// } +inline SourceName EnvironmentNode::AddSource(std::string file_name, + std::string source) { + throw Error("need to restore error handling"); + // return this->source_map_.add_source(file_name, source); +} -TVM_REGISTER_API("relay._make.Environment") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = EnvironmentNode::make({}); - }); +void EnvironmentNode::ReportError(std::string msg, Span sp) { + throw Error("need to restore error handling"); + // this->errors_.push_back(Error(msg, sp)); +} -// TVM_REGISTER_API("relay._env.Environment_add") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// Item item = args[1]; -// env->add(item, true); // REMOVE ME -// }); - -// TVM_REGISTER_API("relay._env.Environment_lookup_global") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// GlobalId id = args[1]; -// *ret = env->lookup(id); -// }); - -// TVM_REGISTER_API("relay._env.Environment_lookup_operator") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// OperatorId id = args[1]; -// *ret = env->lookup(id); -// }); - -// // TVM_REGISTER_API("relay._env.Environment_remove_global") -// // .set_body([](TVMArgs args, TVMRetValue *ret) { -// // Environment env = args[0]; -// // GlobalId id = args[1]; -// // env->remove(id); -// // }); - -// TVM_REGISTER_API("relay._env.Environment_global_id") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// std::string str = args[1]; -// *ret = env->global_id(str); -// }); - -// TVM_REGISTER_API("relay._env.Environment_operator_id") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// std::string str = args[1]; -// *ret = env->operator_id(str); -// }); - -// TVM_REGISTER_API("relay._env.Environment_register_shape_ext") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// ShapeExtension ext = args[1]; -// env->register_shape_ext(ext); -// }); - -// TVM_REGISTER_API("relay._env.Environment_register_primitive") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// std::string str = args[1]; -// *ret = env->global_id(str); -// }); - -// TVM_REGISTER_API("relay._env.Environment_add_source") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// std::string file_name = args[1]; -// std::string source_name = args[2]; -// *ret = env->add_source(file_name, source_name); -// }); - -// TVM_REGISTER_API("relay._env.Environment_report_error") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// std::string msg = args[1]; -// Span sp = args[2]; -// env->report_error(msg, sp); -// }); - -// TVM_REGISTER_API("relay._env.Environment_display_errors") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// return env->display_errors(); -// }); - -// TVM_REGISTER_API("relay._env.Environment_get_operators") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// *ret = env->get_operators(); -// }); - -// TVM_REGISTER_API("relay._env.Environment_get_defns") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// *ret = env->get_defns(); -// }); +void EnvironmentNode::DisplayErrors() { + throw Error("need to restore error printing"); + // for (auto err : this->errors_) { + // auto sp = err.sp; + // auto source_file = this->source_map_.GetSource(err.sp->file_id); + // auto file_name = source_file.file_name; + // auto source_at_span = source_file.SourceAt(err.sp, 1); + // std::string error_marker = "error:"; + // auto line_info = + // std::to_string(sp->lineno) + ":" + std::to_string(sp->col_offset); + + // std::cout << rang::style::bold << rang::fg::red << error_marker + // << rang::fg::reset << file_name << ":" << line_info + // << rang::style::reset << " " << source_at_span << std::endl; + + // // Build the cursor. + + // // Fix this code, hardwired to compute alignment of pointer. + // size_t spaces = error_marker.size() + line_info.size() + file_name.size() + + // sp->col_offset - 3; + + // std::string cursor = "~~~~^~~~~"; + // for (size_t i = 0; i < spaces; i++) { + // std::cout << " "; + // } + // std::cout << rang::fg::red << cursor << " " << err.msg << rang::style::reset + // << std::endl; + // } +} TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const EnvironmentNode *node, tvm::IRPrinter *p) { - p->stream << "EnvironmentNode(todo)"; // << node->items << ")"; + p->stream << "EnvironmentNode( " << node->functions << ")"; }); } // namespace relay From eb61a11c42e279d5955503f21cdf309042469e41 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 17:29:30 -0700 Subject: [PATCH 61/88] Fix up Python imports --- python/tvm/relay/{_pass.py => _ir_pass.py} | 2 +- python/tvm/relay/{_pass.pyi => _ir_pass.pyi} | 0 python/tvm/relay/expr.py | 2 +- python/tvm/relay/ir_builder.py | 21 ++++++++++++------- python/tvm/relay/ir_pass.py | 5 +++++ python/tvm/relay/pass.py | 5 ----- src/relay/ir/environment.cc | 5 +++++ src/relay/pass/type_infer.cc | 4 ++-- .../relay/test_tyck_eval_integration.py | 2 +- 9 files changed, 29 insertions(+), 17 deletions(-) rename python/tvm/relay/{_pass.py => _ir_pass.py} (72%) rename python/tvm/relay/{_pass.pyi => _ir_pass.pyi} (100%) create mode 100644 python/tvm/relay/ir_pass.py delete mode 100644 python/tvm/relay/pass.py diff --git a/python/tvm/relay/_pass.py b/python/tvm/relay/_ir_pass.py similarity index 72% rename from python/tvm/relay/_pass.py rename to python/tvm/relay/_ir_pass.py index 052ba6d4a0fb..61fdcfa38c2f 100644 --- a/python/tvm/relay/_pass.py +++ b/python/tvm/relay/_ir_pass.py @@ -2,4 +2,4 @@ from tvm._ffi.function import _init_api -_init_api("relay._pass", __name__) +_init_api("relay._ir_pass", __name__) diff --git a/python/tvm/relay/_pass.pyi b/python/tvm/relay/_ir_pass.pyi similarity index 100% rename from python/tvm/relay/_pass.pyi rename to python/tvm/relay/_ir_pass.pyi diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 41066829e2f3..4f558210fb11 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -6,7 +6,7 @@ from .base import Span, NodeBase, register_relay_node from .type import Type, TypeParam from tvm import expr -from ._type_infer import _get_checked_type +from ._ir_pass import _get_checked_type from . import _make class Expr(NodeBase): diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index a9cb02a19025..b5ca6428c897 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -1,7 +1,7 @@ from typing import Any import numpy as np import tvm -from .type import FloatType, IntType, BoolType, UIntType, FuncType, TensorType +from .type import FuncType, TensorType from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function from . import op as _op @@ -152,20 +152,27 @@ def get(self): def bool_dtype(): return 'uint1' -def int_dtype(): - return 'uint1' +def int_dtype(bits=32): + return f'int1{bits}' + +def float_dtype(bits=32): + return f'float{bits}' +def uint_dtype(bits=32): + return f'fuint{bits}' + def int_type(bits=32, lanes=1): - return IntType(bits, lanes) + # TODO(@jroesch, @tqchen) How do we set lanes? + return TensorType(tvm.convert([]), int_dtype(bits)) def uint_type(bits=32, lanes=1): - return UIntType(bits, lanes) + return TensorType(tvm.convert([]), uint_dtype(bits)) def float_type(bits=32, lanes=1): - return FloatType(bits, lanes) + return TensorType(tvm.convert([]), float_dtype(bits)) def bool_type(lanes=1): - return BoolType(lanes) + return TensorType(tvm.convert([]), bool_dtype(bits)) def tensor_type(*shape, dtype='float32'): return TensorType(tvm.convert(shape), dtype) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py new file mode 100644 index 000000000000..ad7a68eac392 --- /dev/null +++ b/python/tvm/relay/ir_pass.py @@ -0,0 +1,5 @@ +#pylint: disable-all + +from . import _ir_pass + +check_expr = _ir_pass.check_expr diff --git a/python/tvm/relay/pass.py b/python/tvm/relay/pass.py deleted file mode 100644 index 8c352e58843d..000000000000 --- a/python/tvm/relay/pass.py +++ /dev/null @@ -1,5 +0,0 @@ -#pylint: disable-all - -from . import _pass - -check_expr = _type_infer.check_expr diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index 63a42b9d0e3e..cb8afd002c51 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -151,6 +151,11 @@ void EnvironmentNode::DisplayErrors() { // } } +TVM_REGISTER_API("relay._make.Environment") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = EnvironmentNode::make(args[0]); + }); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const EnvironmentNode *node, tvm::IRPrinter *p) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 746323fc6d56..383196f49be9 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -626,7 +626,7 @@ class TypeInferencer : private ExprFunctor { } } - TVM_REGISTER_API("relay._pass.check_expr") + TVM_REGISTER_API("relay._ir_pass.check_expr") .set_body([](TVMArgs args, TVMRetValue *ret) { Environment env = args[0]; Expr e = args[1]; @@ -634,7 +634,7 @@ class TypeInferencer : private ExprFunctor { }); // TODO(@jroesch): put in a better namespace. - TVM_REGISTER_API("relay._pass._get_checked_type") + TVM_REGISTER_API("relay._ir_pass._get_checked_type") .set_body([](TVMArgs args, TVMRetValue *ret) { Expr e = args[0]; *ret = e->checked_type(); diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 6e5b64ee846e..72fd995fd22e 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -1,7 +1,7 @@ """Test that type checker correcly computes types for expressions. """ -from tvm.relay.type_infer import check_expr +from tvm.relay.ir_pass import check_expr from tvm.relay.ir_builder import IRBuilder, float_type, func_type, tensor_type from tvm.relay.env import Environment from tvm.relay.op import log, add From 4dd336d7444884568e00a6f34898e536c564f31c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Aug 2018 22:50:00 -0700 Subject: [PATCH 62/88] Add first pass add broadcast inference --- src/relay/op/type_relations.cc | 33 ++++++++++++++----- .../relay/test_tyck_eval_integration.py | 4 +-- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 56b139731178..d97b8f96e85c 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -24,6 +24,7 @@ TensorType as_ttype(const Type & t) { int to_int(const tvm::Expr & e) { auto imm = e.as(); CHECK(imm); + std::cout << "TYPE: " << imm << imm->type << std::endl; return imm->value; } @@ -60,26 +61,41 @@ static Type ConcreteBroadcast(const TensorType & t1, const TensorType & t2) { } } - Array larger; - Array smaller; + Array larger; + Array smaller; for (int i = 0; i < (full_len - suffix_len); i++) { - // smaller.push_back(tvm::ir::IntImm::make(1)); + smaller.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), 1)); } if (sh1.size() < sh2.size()) { - + for (auto sh : sh1) { + smaller.push_back(sh); + } + larger = sh2; } else if (sh1.size() > sh2.size()) { - + for (auto sh : sh1) { + larger.push_back(sh); + } + smaller = sh2; } else { - + larger = sh1; + smaller = sh2; } - for (int i = 0; i < suffix_len - full_len; i++) { + CHECK(larger.size() == smaller.size()); + Array out_shape; + for (int i = 0; i < smaller.size(); i++) { + auto left = smaller[i].as(); + auto right = larger[i].as(); + CHECK(left); + CHECK(right); + int64_t dim = std::max(left->value, right->value); + out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim)); } - return t1; + return TensorTypeNode::make(out_shape, t1->dtype); } Array BroadcastRel(const Array & types, int num_args) { @@ -89,6 +105,7 @@ Array BroadcastRel(const Array & types, int num_args) { return { t1, t2, ConcreteBroadcast(t1, t2) }; } } + return types; } diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 72fd995fd22e..e928cd5cb76a 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -47,6 +47,6 @@ def test_dual_op(): assert has_type(func.to_func(), func_type([float_type()], float_type())) if __name__ == "__main__": - # test_monomorphic_let() - # test_single_op() + test_monomorphic_let() + test_single_op() test_dual_op() From f73452175d2c7f684382e3028d80c4988f577674 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 31 Aug 2018 16:05:34 -0700 Subject: [PATCH 63/88] Add ability to build and check a global --- include/tvm/relay/environment.h | 4 + python/tvm/relay/env.py | 16 + python/tvm/relay/ir_builder.py | 41 +- src/relay/ir/environment.cc | 125 ++- src/relay/ir/expr.cc | 4 +- src/relay/pass/type_infer.cc | 961 +++++++++--------- .../relay/test_tyck_eval_integration.py | 25 +- 7 files changed, 630 insertions(+), 546 deletions(-) diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index aa41882db46e..5ad7ba8e0010 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -63,6 +63,7 @@ class EnvironmentNode : public RelayNode { void Update(const GlobalVar& var, const Function & func); void Remove(const GlobalVar& var); + /*! \brief Lookup a global function by its variable. */ GlobalVar GetGlobalVar(const std::string& str); /*! \brief Lookup a global function by its variable. */ @@ -70,6 +71,9 @@ class EnvironmentNode : public RelayNode { /*! \brief Lookup a global function by its string name */ Function Lookup(const std::string & s); + + // TODO(@jroesch, @tqchen): what are the semantics here + void Merge(const Environment & env); /*! \brief Add a source fragment to the environment. */ SourceName AddSource(std::string file_name, std::string source); diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 4de5a0c02772..186ee8854c35 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -3,6 +3,7 @@ from typing import Union, List from .base import register_relay_node, NodeBase from . import _make +from . import _env import tvm @register_relay_node @@ -12,3 +13,18 @@ class Environment(NodeBase): """ def __init__(self, funcs) -> None: self.__init_handle_by_constructor__(_make.Environment, funcs) + + def add(self, var, func) -> None: + if isinstance(var, str): + var = _env.Environment_GetGlobalVar(self, var) + + _env.Environment_Add(self, var, func) + + def merge(self, other): + return _env.Environment_Merge(self, other) + + def lookup(self, var): + if isinstance(var, str): + return _env.Environment_Lookup_str(self, var) + else: + return _env.Environment_Lookup(self, var) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index b5ca6428c897..50ebeb1bb12d 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -3,6 +3,7 @@ import tvm from .type import FuncType, TensorType from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function +from .env import Environment from . import op as _op class ExprBuilder(): @@ -83,6 +84,7 @@ def __init__(self): self.scopes = [{}] self.params = [] self.ret_value = None + self.env = Environment({}) def bind(self, name, type, value): @@ -93,6 +95,9 @@ def bind(self, name, type, value): def let(self, name, value, value_type=None): + if isinstance(value, Param): + value = value.var + if not (isinstance(value, Expr) or isinstance(value, ExprBuilder)): value = into_ast(value) @@ -131,8 +136,29 @@ def ret(self, x): raise Exception( "return value already set, a function can only have one return value") - def fn_params(self): - pass + def param(self, name, ty=None): + if not ty: + ty = float_type() + + return Param(LocalVar(name), ty) + + # def params(*args): + # i = 0 + # while i < args.size(): + # arg = args[i] + # if isinstance(arg, str): + + + def decl(self, name: str, *params): + decl_builder = IRBuilder() + + def _on_exit(): + exp, sub_env = decl_builder.get() + self.env.add(name, Function(params, None, exp)) + self.env.merge(sub_env) + + return WithScope(decl_builder, _on_exit) + def get(self): """Get the full program""" @@ -140,14 +166,15 @@ def get(self): scope = self.scopes.pop() if self.bindings: - raise Exception("...") + raise Exception("IRBuilder: binding error") + if self.scopes: - raise Exception("...") + raise Exception("IRBuilder: scoping error") - if not self.ret_value: - raise Exception("...") + if bindings and scope and not self.ret_value: + raise Exception("IRBuilder: no return value set") - return _mk_let(bindings, self.ret_value) + return _mk_let(bindings, self.ret_value), self.env def bool_dtype(): return 'uint1' diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index cb8afd002c51..7861fb58820b 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -4,14 +4,18 @@ * \brief The global environment in Relay. */ #include -#include "tvm/relay/environment.h" +#include +#include +#include +#include +#include "./../pass/resolve.h" // #include "tvm/relay/util/rang.h" namespace tvm { namespace relay { using tvm::IRPrinter; -using namespace tvm::runtime; +using namespace runtime; Environment EnvironmentNode::make( tvm::Map global_funcs) { @@ -37,53 +41,35 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { * update the definition if and only if it is type compatible. */ void EnvironmentNode::Add(const GlobalVar& var, const Function & func, bool update) { - throw Error("NYI"); -// // Type check the item before we add it to the environment. -// auto env = GetRef(this); -// Item item = check(env, unchecked_item); - -// if (const OperatorNode *op_node = item.as()) { -// Operator op = GetRef(op_node); -// auto type = op->type; -// if (operators.find(op->id) != operators.end()) { -// if (!update) { -// throw dmlc::Error("already have definition for XXXX."); -// } - -// auto old_type = operators[op->id]->type; - -// if (!alpha_eq(type, old_type)) { -// throw dmlc::Error( -// "Environment#update changes type, not possible in this mode."); -// } - -// operators.insert({op->id, op}); -// } else { -// operators.insert({op->id, op}); -// } -// } else if (const FunctionNode *d = item.as()) { -// auto def = GetRef(d); -// auto type = def->type; -// if (items.find(def->id) != items.end()) { -// if (!update) { -// throw dmlc::Error("already have definition for XXXX."); -// } - -// auto old_type = items[def->id].as()->type; - -// if (!alpha_eq(type, old_type)) { -// throw dmlc::Error( -// "Environment#update changes type, not possible in this mode."); -// } - -// this->items.insert({def->id, def}); -// } else { -// this->items.insert({def->id, def}); -// } -// } else { -// throw EnvError("internal error: unknown item type, unreachable code"); -// } -// } + // Type check the item before we add it to the environment. + auto env = GetRef(this); + Expr checked_expr = InferType(env, func); + + if (const FunctionNode *func_node = checked_expr.as()) { + auto checked_func = GetRef(func_node); + auto type = checked_func->checked_type(); + + CHECK(IsFullyResolved(type)); + + if (functions.find(var) != functions.end()) { + if (!update) { + throw dmlc::Error("already have definition for XXXX."); + } + + auto old_type = functions[var].as()->checked_type(); + + if (!AlphaEqual(type, old_type)) { + throw dmlc::Error( + "Environment#update changes type, not possible in this mode."); + } + + this->functions.Set(var, checked_func); + } else { + this->functions.Set(var, checked_func); + } + } else { + throw Error("internal error: unknown item type, unreachable code"); + } } void EnvironmentNode::Update(const GlobalVar& var, const Function & func) { @@ -110,6 +96,13 @@ Function EnvironmentNode::Lookup(const std::string &str) { return this->Lookup(id); } +void EnvironmentNode::Merge(const Environment & env) { + for (auto pair : env->functions) { + this->functions.Set(pair.first, pair.second); + } +} + + inline SourceName EnvironmentNode::AddSource(std::string file_name, std::string source) { throw Error("need to restore error handling"); @@ -156,6 +149,40 @@ TVM_REGISTER_API("relay._make.Environment") *ret = EnvironmentNode::make(args[0]); }); +TVM_REGISTER_API("relay._env.Environment_Add") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + env->Add(args[1], args[2], false); + }); + +TVM_REGISTER_API("relay._env.Environment_GetGlobalVar") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + *ret = env->GetGlobalVar(args[1]); + }); + +TVM_REGISTER_API("relay._env.Environment_Lookup") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + GlobalVar var = args[1]; + *ret = env->Lookup(var); + }); + +TVM_REGISTER_API("relay._env.Environment_Lookup_str") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + std::string var_name = args[1]; + auto var = env->GetGlobalVar(var_name); + *ret = env->Lookup(var); + }); + +TVM_REGISTER_API("relay._env.Environment_Merge") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + env->Merge(args[1]); + }); + + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const EnvironmentNode *node, tvm::IRPrinter *p) { diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 2b235e8b01ad..47d253e91c21 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -168,8 +168,8 @@ TVM_REGISTER_API("relay._make.Let") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const LetNode *node, tvm::IRPrinter *p) { - p->stream << "LetNode(" << node->var << node->value << node->body - << node->value_type << ")"; + p->stream << "LetNode(" << node->var << ", " << node->value + << ", " << node->body << ", " << node->value_type << ")"; }); If IfNode::make(Expr cond, Expr true_value, Expr false_value) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 383196f49be9..514df129503a 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -21,14 +21,14 @@ */ #include +#include #include #include -#include #include "./incomplete_type.h" -#include "./unifier.h" #include "./resolve.h" #include "./type_subst.h" #include "./type_visitor.h" +#include "./unifier.h" // #include "tvm/relay/typeck/kindchecker.h" namespace tvm { @@ -61,9 +61,9 @@ struct TypeContext { struct TypeNormalizer : TypeFVisitor { TypeUnifier unifier; - TypeNormalizer(const TypeUnifier & unifier) : unifier(unifier) {} + TypeNormalizer(const TypeUnifier &unifier) : unifier(unifier) {} - Type VisitType_(const TypeCallNode * ty_call_node) { + Type VisitType_(const TypeCallNode *ty_call_node) { auto ty_call = GetRef(ty_call_node); Array normalized_args; @@ -71,7 +71,7 @@ struct TypeNormalizer : TypeFVisitor { for (auto arg : ty_call->args) { normalized_args.push_back(VisitType(arg)); } - + auto all_concrete = true; for (auto arg : normalized_args) { all_concrete = all_concrete && !arg.as(); @@ -82,7 +82,8 @@ struct TypeNormalizer : TypeFVisitor { } else { if (auto ty_rel_node = ty_call->func.as()) { // NB: we substract 1 for the output argument. - auto new_args = ty_rel_node->func_(ty_call->args, ty_call->args.size() - 1); + auto new_args = + ty_rel_node->func_(ty_call->args, ty_call->args.size() - 1); CHECK(new_args.size() == normalized_args.size()); tvm::Array final_args; @@ -110,554 +111,544 @@ class TypeInferencer : private ExprFunctor { TypeContext local_stack; public: - Environment env; - TypeUnifier unifier; - - // Should be in header? - template - T with_frame(const std::function & f) { - TypeContext::LocalFrame fr(local_stack); - return f(); - } - - TypeInferencer(); - TypeInferencer(Environment env, TypeUnifier unifier) : env(env), - unifier(unifier) {} explicit TypeInferencer(Environment env); - - CheckedExpr Infer(const Expr & expr); - - FuncType instantiate(FuncType fn_ty, tvm::Array &ty_args); - - Type Normalize(const Type & t); - - void report_error(const std::string & msg, Span sp); - [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); - - Type unify(const Type &t1, const Type &t2, Span sp); - Type resolve(const Type &t); - Expr resolve(const Expr &e); - CheckedExpr VisitFunction(const Function & f, bool generalize); - void CheckOp(Op op); - // Defn CheckDefn(Defn def); - private: - CheckedExpr VisitExpr_(const LocalVarNode* op) override; - CheckedExpr VisitExpr_(const GlobalVarNode* op) override; - CheckedExpr VisitExpr_(const ConstantNode* op) override; - CheckedExpr VisitExpr_(const TupleNode* op) override; - CheckedExpr VisitExpr_(const ParamNode* op) override; - CheckedExpr VisitExpr_(const FunctionNode* op) override; - CheckedExpr VisitExpr_(const CallNode* op) override; - CheckedExpr VisitExpr_(const LetNode* op) override; - CheckedExpr VisitExpr_(const IfNode* op) override; - CheckedExpr VisitExpr_(const OpNode* op) override; -}; - - TypeInferencer::TypeInferencer() { - this->env = EnvironmentNode::make({}); - this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); - } - - TypeInferencer::TypeInferencer(Environment env) : env(env) { - this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); - } - - Type TypeInferencer::Normalize(const Type & t) { - auto nt = this->resolve(t); - auto normalizer = TypeNormalizer(this->unifier); - return normalizer.VisitType(nt); - } - - CheckedExpr TypeInferencer::Infer(const Expr &expr) { - RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl; - CheckedExpr checked_expr = this->VisitExpr(expr); - RELAY_LOG(INFO) << "TypeInferencer::Check type=" << checked_expr.type << std::endl; - Type final_type = Normalize(checked_expr.type); - RELAY_LOG(INFO) << "TypeInferencer::Check type_after_subst=" << final_type << std::endl; - checked_expr.expr->checked_type_ = final_type; - return checked_expr; - } + Environment env; + TypeUnifier unifier; - CheckedExpr TypeInferencer::VisitExpr_(const LocalVarNode *op) { - auto var = GetRef(op); - return { var, this->local_stack.lookup(var) }; + // Should be in header? + template + T with_frame(const std::function &f) { + TypeContext::LocalFrame fr(local_stack); + return f(); } - CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode *op) { - // GlobalVar id = GetRef(op); - // Item item = this->env->lookup(id); + TypeInferencer(); + TypeInferencer(Environment env, TypeUnifier unifier) + : env(env), unifier(unifier) {} + explicit TypeInferencer(Environment env); - // if (const OpNode *op = item.as()) { - // return op->type; - // } + CheckedExpr Infer(const Expr &expr); - // if (const DefnNode *dn = item.as()) { - // Defn def = GetRef(dn); - // return def->type; - // } + FuncType instantiate(FuncType fn_ty, tvm::Array &ty_args); - // this->fatal_error("Unhandled case in GlobalId", op->span); - throw Error("hereeee"); - } - - CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode *const_node) { - return { GetRef(const_node), const_node->tensor_type() }; - } + Type Normalize(const Type &t); - CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { - // Tuple pl = GetRef(op); + void report_error(const std::string &msg, Span sp); + [[noreturn]] void fatal_error(const std::string &msg, Span sp); - // std::vector field_types; - // for (auto field = pl->fields.begin(); field != pl->fields.end(); field++) - // { - // field_types.push_back(this->Check(*field)); - // } + Type unify(const Type &t1, const Type &t2, Span sp); + Type resolve(const Type &t); + Expr resolve(const Expr &e); + CheckedExpr VisitFunction(const Function &f, bool generalize); + void CheckOp(Op op); + // Defn CheckDefn(Defn def); + private: + CheckedExpr VisitExpr_(const LocalVarNode *op) override; + CheckedExpr VisitExpr_(const GlobalVarNode *op) override; + CheckedExpr VisitExpr_(const ConstantNode *op) override; + CheckedExpr VisitExpr_(const TupleNode *op) override; + CheckedExpr VisitExpr_(const ParamNode *op) override; + CheckedExpr VisitExpr_(const FunctionNode *op) override; + CheckedExpr VisitExpr_(const CallNode *op) override; + CheckedExpr VisitExpr_(const LetNode *op) override; + CheckedExpr VisitExpr_(const IfNode *op) override; + CheckedExpr VisitExpr_(const OpNode *op) override; +}; - // return TupleTypeNode::make(field_types); - throw Error("TupleNode NYI"); +TypeInferencer::TypeInferencer() { + this->env = EnvironmentNode::make({}); + this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); +} + +TypeInferencer::TypeInferencer(Environment env) : env(env) { + this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); +} + +Type TypeInferencer::Normalize(const Type &t) { + auto nt = this->resolve(t); + auto normalizer = TypeNormalizer(this->unifier); + return normalizer.VisitType(nt); +} + +CheckedExpr TypeInferencer::Infer(const Expr &expr) { + RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl; + CheckedExpr checked_expr = this->VisitExpr(expr); + RELAY_LOG(INFO) << "TypeInferencer::Check type=" << checked_expr.type + << std::endl; + Type final_type = Normalize(checked_expr.type); + RELAY_LOG(INFO) << "TypeInferencer::Check type_after_subst=" << final_type + << std::endl; + checked_expr.expr->checked_type_ = final_type; + return checked_expr; +} + +CheckedExpr TypeInferencer::VisitExpr_(const LocalVarNode *op) { + auto var = GetRef(op); + return {var, this->local_stack.lookup(var)}; +} + +CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode *op) { + GlobalVar var = GetRef(op); + Expr e = this->env->Lookup(var); + return { var, e->checked_type() }; +} + +CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode *const_node) { + return {GetRef(const_node), const_node->tensor_type()}; +} + +CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { + Tuple pl = GetRef(op); + + std::vector field_exprs; + std::vector field_types; + for (auto field = pl->fields.begin(); field != pl->fields.end(); field++) { + auto checked_field = Infer(*field); + field_exprs.push_back(checked_field.expr); + field_types.push_back(checked_field.type); } - CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { - auto rtype = resolve(param->type); - return { ParamNode::make(param->var, rtype), rtype }; - } + return { TupleNode::make(field_exprs), TupleTypeNode::make(field_types) }; +} + +CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { + auto rtype = resolve(param->type); + return {ParamNode::make(param->var, rtype), rtype}; +} + +// // We should probably generalize the subst code. +// struct GeneralizeTypeType : TypeFVisitor { +// Map vars_to_id; +// const TypeUnifier &unifier; + +// GeneralizeTypeType(Map vars_to_id, +// const TypeUnifier &unifier) +// : vars_to_id(vars_to_id), unifier(unifier) {} + +// Type VisitType_(const TypeVarNode *op) override { +// auto repr = unifier->subst(GetRef(op)); +// if (auto tvn = repr.as()) { +// auto ty_var = GetRef(tvn); +// if (vars_to_id.find(ty_var) != vars_to_id.end()) { +// return vars_to_id[ty_var]; +// } else { +// return ty_var; +// } +// } else { +// return this->VisitType(repr); +// } +// } +// }; + +// struct GeneralizeTypeExpr : ExprFVisitor<> { +// Map vars_to_id; +// const TypeUnifier &unifier; + +// GeneralizeTypeExpr(const TypeUnifier &unifier, +// Map vars_to_id) +// : vars_to_id(vars_to_id), unifier(unifier) {} + +// Type VisitType(const Type &t) { +// return GeneralizeTypeType(vars_to_id, unifier).VisitType(t); +// } +// }; + +CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { + // First we add the parameters to the context allowing us to check their + // types. + + // TODO(@jroesch): support polymorphism + + std::vector param_types; + std::vector params; + + return this->with_frame([&]() -> CheckedExpr { + for (auto param : f->params) { + CheckedExpr checked_param = this->Infer(param); + Type arg_type; + param_types.push_back(checked_param.type); + params.push_back(GetRef(checked_param.expr.as())); + this->local_stack.insert(param->var, checked_param.type); + } - // // We should probably generalize the subst code. - // struct GeneralizeTypeType : TypeFVisitor { - // Map vars_to_id; - // const TypeUnifier &unifier; - - // GeneralizeTypeType(Map vars_to_id, - // const TypeUnifier &unifier) - // : vars_to_id(vars_to_id), unifier(unifier) {} - - // Type VisitType_(const TypeVarNode *op) override { - // auto repr = unifier->subst(GetRef(op)); - // if (auto tvn = repr.as()) { - // auto ty_var = GetRef(tvn); - // if (vars_to_id.find(ty_var) != vars_to_id.end()) { - // return vars_to_id[ty_var]; - // } else { - // return ty_var; - // } + auto checked_body = this->Infer(f->body); + auto inferred_rtype = checked_body.type; + auto annotated_rtype = resolve(f->ret_type); + + auto unified_rtype = this->unify(inferred_rtype, annotated_rtype, f->span); + + return {FunctionNode::make(params, unified_rtype, checked_body.expr, {}), + FuncTypeNode::make(param_types, unified_rtype, {}, {})}; + }); + + // // typecheck body and ensure that it matches stated return type + // // TODO(sslyu): should the unified return type override the annotated + // one? Type checked_return = this->Check(f->body); Type ret_type = + // resolve(f->ret_type); Type unified = + // this->unify(simple_eval_shape(ret_type), + // simple_eval_shape(checked_return), f->span); + // return TypeArrowNode::make(arg_types, unified); + // }); + // if (generalize) { + // auto free_vars = free_type_vars(resolve(fn_type)); + // std::set dedup_free_vars; + + // for (auto free_var : free_vars) { + // auto repr = this->unifier->subst(free_var); + // if (auto new_free_var_node = repr.as()) { + // dedup_free_vars.insert(GetRef(new_free_var_node)); // } else { - // return this->VisitType(repr); + // // debug(repr); + // throw dmlc::Error( + // "internal error: this list should only contain type var + // nodes"); // } // } - // }; - // struct GeneralizeTypeExpr : ExprFVisitor<> { // Map vars_to_id; - // const TypeUnifier &unifier; - // GeneralizeTypeExpr(const TypeUnifier &unifier, - // Map vars_to_id) - // : vars_to_id(vars_to_id), unifier(unifier) {} - - // Type VisitType(const Type &t) { - // return GeneralizeTypeType(vars_to_id, unifier).VisitType(t); + // GenFresh gf; + // for (auto free_var : dedup_free_vars) { + // vars_to_id.Set(free_var, gf.freshTV(free_var->kind)); // } - // }; - - CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { - // First we add the parameters to the context allowing us to check their - // types. - - // TODO(@jroesch): support polymorphism - - std::vector param_types; - std::vector params; - - return this->with_frame([&]() -> CheckedExpr { - for (auto param : f->params) { - CheckedExpr checked_param = this->Infer(param); - Type arg_type; - param_types.push_back(checked_param.type); - params.push_back(GetRef(checked_param.expr.as())); - this->local_stack.insert(param->var, checked_param.type); - } - - auto checked_body = this->Infer(f->body); - auto inferred_rtype = checked_body.type; - auto annotated_rtype = resolve(f->ret_type); - - auto unified_rtype = this->unify(inferred_rtype, annotated_rtype, f->span); - - return { FunctionNode::make(params, unified_rtype, checked_body.expr, {}), - FuncTypeNode::make(param_types, unified_rtype, {}, {}) }; - }); - - // // typecheck body and ensure that it matches stated return type - // // TODO(sslyu): should the unified return type override the annotated - // one? Type checked_return = this->Check(f->body); Type ret_type = - // resolve(f->ret_type); Type unified = - // this->unify(simple_eval_shape(ret_type), - // simple_eval_shape(checked_return), f->span); - // return TypeArrowNode::make(arg_types, unified); - // }); - // if (generalize) { - // auto free_vars = free_type_vars(resolve(fn_type)); - // std::set dedup_free_vars; - - // for (auto free_var : free_vars) { - // auto repr = this->unifier->subst(free_var); - // if (auto new_free_var_node = repr.as()) { - // dedup_free_vars.insert(GetRef(new_free_var_node)); - // } else { - // // debug(repr); - // throw dmlc::Error( - // "internal error: this list should only contain type var - // nodes"); - // } - // } - - // Map vars_to_id; - - // GenFresh gf; - // for (auto free_var : dedup_free_vars) { - // vars_to_id.Set(free_var, gf.freshTV(free_var->kind)); - // } - - // fn_type = GeneralizeTypeType(vars_to_id, unifier).VisitType(fn_type); - // for (std::pair pair : vars_to_id) { - // // NB: In generalization we want to find type variables with - // // *no constraints* on them, and convert them to universally - // quantified - // // variables. - // // - // // i.e the program can be abstracted over the details of *that* type. - - // // For example a program that works irrespective of shape or - // datatype. - - // // In order to do this we find the set of free type variables in the - // // term, and then unify them with the fresh type ids we generate. - // // - // // Remember importantly these type variables still may appear in many - // // places in the program including both types and expressions. - - // // Our method for resolving these is to unify them with the variables - // // as we build the new quanitifer, changing from a program with - // "holes" - // // to one that is properly abstracted over. - - // // Finally later on we can iterate over the whole term and change - // from - // // type variables to these type ids. - // this->unify(pair.first, pair.second, pair.second->span); - // fn_type = TypeQuantifierNode::make(pair.second, fn_type); - // } - // } else { - // for (auto i = f->ty_params.size(); i > 0; i--) { - // auto ty_param = f->ty_params[i - 1]; - // auto ty_param_node = ty_param.as(); - // if (!ty_param_node) { - // throw dmlc::Error("internal error should be TypeParam"); - // } - // auto fresh_tid = - // TypeParamNode::make(ty_param_node->name, ty_param_node->kind); - // fn_type = - // TypeSubst(fn_type, GetRef(ty_param_node), fresh_tid); - // fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); - // } - // } - - // return fn_type; - } - CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode *op) { - return this->VisitFunction(GetRef(op), false); - } + // fn_type = GeneralizeTypeType(vars_to_id, unifier).VisitType(fn_type); + // for (std::pair pair : vars_to_id) { + // // NB: In generalization we want to find type variables with + // // *no constraints* on them, and convert them to universally + // quantified + // // variables. + // // + // // i.e the program can be abstracted over the details of *that* type. + + // // For example a program that works irrespective of shape or + // datatype. + + // // In order to do this we find the set of free type variables in the + // // term, and then unify them with the fresh type ids we generate. + // // + // // Remember importantly these type variables still may appear in many + // // places in the program including both types and expressions. + + // // Our method for resolving these is to unify them with the variables + // // as we build the new quanitifer, changing from a program with + // "holes" + // // to one that is properly abstracted over. + + // // Finally later on we can iterate over the whole term and change + // from + // // type variables to these type ids. + // this->unify(pair.first, pair.second, pair.second->span); + // fn_type = TypeQuantifierNode::make(pair.second, fn_type); + // } + // } else { + // for (auto i = f->ty_params.size(); i > 0; i--) { + // auto ty_param = f->ty_params[i - 1]; + // auto ty_param_node = ty_param.as(); + // if (!ty_param_node) { + // throw dmlc::Error("internal error should be TypeParam"); + // } + // auto fresh_tid = + // TypeParamNode::make(ty_param_node->name, ty_param_node->kind); + // fn_type = + // TypeSubst(fn_type, GetRef(ty_param_node), fresh_tid); + // fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); + // } + // } - FuncType TypeInferencer::instantiate(FuncType fn_ty, tvm::Array &ty_args) { - tvm::Map subst_map; + // return fn_type; +} - // Build a subsitituion map up from the function type and type arguments. - // Eventually allow the type vars to be passed in. - for (auto ty_param : fn_ty->type_params) { - IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); - this->unifier->insert(fresh); - ty_args.push_back(fresh); - subst_map.Set(ty_param, fresh); - } +CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode *op) { + return this->VisitFunction(GetRef(op), false); +} - Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); - inst_ty = TypeSubst(fn_ty, subst_map); +FuncType TypeInferencer::instantiate(FuncType fn_ty, + tvm::Array &ty_args) { + tvm::Map subst_map; - // if (!check_kind(t)) { - // this->fatal_error("Kind rules broken when instantiating type - // variables", - // t->span); - // } - - return GetRef(inst_ty.as()); + // Build a subsitituion map up from the function type and type arguments. + // Eventually allow the type vars to be passed in. + for (auto ty_param : fn_ty->type_params) { + IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); + this->unifier->insert(fresh); + ty_args.push_back(fresh); + subst_map.Set(ty_param, fresh); } - CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { - Call c = GetRef(op); + Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); + inst_ty = TypeSubst(fn_ty, subst_map); - auto checked_op = this->Infer(c->op); + // if (!check_kind(t)) { + // this->fatal_error("Kind rules broken when instantiating type + // variables", + // t->span); + // } - RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl - << "fn_ty=" << checked_op.type << std::endl; + return GetRef(inst_ty.as()); +} +CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { + Call c = GetRef(op); - auto fn_ty_node = checked_op.type.as(); + auto checked_op = this->Infer(c->op); - if (!fn_ty_node) { - this->fatal_error("only expressions with function types can be called", c->op->span); - } + RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl + << "fn_ty=" << checked_op.type << std::endl; - // We now have a function type. - FuncType fn_ty = GetRef(fn_ty_node); - - tvm::Array ty_args; - if (ty_args.size() != 0) { - throw Error("found manually suplied type args, not supported"); - } + auto fn_ty_node = checked_op.type.as(); - fn_ty = instantiate(fn_ty, ty_args); + if (!fn_ty_node) { + this->fatal_error("only expressions with function types can be called", + c->op->span); + } - std::vector arg_types; - std::vector checked_args; + // We now have a function type. + FuncType fn_ty = GetRef(fn_ty_node); - for (auto arg : c->args) { - auto checked_arg = this->Infer(arg); - arg_types.push_back(checked_arg.type); - checked_args.push_back(checked_arg.expr); - } + tvm::Array ty_args; + if (ty_args.size() != 0) { + throw Error("found manually suplied type args, not supported"); + } - auto type_arity = fn_ty->arg_types.size(); - auto number_of_args = arg_types.size(); + fn_ty = instantiate(fn_ty, ty_args); - if (type_arity != number_of_args) { - if (type_arity < number_of_args) { - this->fatal_error("the function is provided too many arguments", - c->span); - } else { - this->fatal_error("the function is provided too few arguments", - c->span); - } - } + std::vector arg_types; + std::vector checked_args; - for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { - this->unify(fn_ty->arg_types[i], arg_types[i], c->args[i]->span); - } + for (auto arg : c->args) { + auto checked_arg = this->Infer(arg); + arg_types.push_back(checked_arg.type); + checked_args.push_back(checked_arg.expr); + } - // After we unify the arguments we should know more about the type - // arguments, let's run a quick pass over them to find new - // representatives. + auto type_arity = fn_ty->arg_types.size(); + auto number_of_args = arg_types.size(); - for (size_t i = 0; i < ty_args.size(); i++) { - ty_args.Set(i, this->unifier->subst(ty_args[i])); + if (type_arity != number_of_args) { + if (type_arity < number_of_args) { + this->fatal_error("the function is provided too many arguments", c->span); + } else { + this->fatal_error("the function is provided too few arguments", c->span); } + } - auto new_call = CallNode::make(checked_op.expr, checked_args, c->attrs, ty_args); - - return { new_call, fn_ty->ret_type }; + for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { + this->unify(fn_ty->arg_types[i], arg_types[i], c->args[i]->span); } - CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) { - Let let = GetRef(op); + // After we unify the arguments we should know more about the type + // arguments, let's run a quick pass over them to find new + // representatives. - CheckedExpr checked_value; - Type annotated_ty = resolve(let->value_type); + for (size_t i = 0; i < ty_args.size(); i++) { + ty_args.Set(i, this->unifier->subst(ty_args[i])); + } + auto new_call = + CallNode::make(checked_op.expr, checked_args, c->attrs, ty_args); - // If we are let-defining a function, we want to be able to - // recursively name the function in order to support recursive - // local definitions. - if (let->value.as()) { - with_frame([&]() { - local_stack.insert(let->var, annotated_ty); - checked_value = Infer(let->value); - }); - } else { - checked_value = Infer(let->value); - } + return {new_call, fn_ty->ret_type}; +} - Type unified_ty = - this->unify(checked_value.type, annotated_ty, let->span); +CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) { + Let let = GetRef(op); - // Update type context with unified type now that we have - // solved this equation. - local_stack.insert(let->var, unified_ty); + CheckedExpr checked_value; + Type annotated_ty = resolve(let->value_type); - auto checked_body = with_frame([&]() { - local_stack.insert(let->var, unified_ty); - return Infer(let->body); + // If we are let-defining a function, we want to be able to + // recursively name the function in order to support recursive + // local definitions. + if (let->value.as()) { + with_frame([&]() { + local_stack.insert(let->var, annotated_ty); + checked_value = Infer(let->value); }); - - auto checked_let = LetNode::make( - let->var, - checked_value.expr, - checked_body.expr, - let->value_type); - - return { checked_let, checked_body.type }; + } else { + checked_value = Infer(let->value); } - CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { - If ifn = GetRef(op); - - // Ensure the type of the guard is of Tensor[Bool, ()], - // that is a rank-0 boolean tensor. - auto checked_cond = this->Infer(ifn->cond); - auto cond_type = checked_cond.type; - - if (const TensorTypeNode *tt_node = cond_type.as()) { - TensorType tt = GetRef(tt_node); - if (tt->dtype.is_bool() && tt->shape.size() == 0) { - auto checked_true = this->Infer(ifn->true_value); - auto checked_false = this->Infer(ifn->false_value); - auto unified_type = this->unify(checked_true.type, checked_false.type, ifn->span); - auto checked_if = IfNode::make(checked_cond.expr, checked_true.expr, checked_false.expr); - return { checked_if, unified_type }; - } - } + Type unified_ty = this->unify(checked_value.type, annotated_ty, let->span); - this->fatal_error("if-then-else guard must be a rank-0 boolean tensor", - ifn->cond->span); - } + // Update type context with unified type now that we have + // solved this equation. + local_stack.insert(let->var, unified_ty); - CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) { - auto op = GetRef(op_node); - return { op, op->op_type }; - } - - Type TypeInferencer::resolve(const Type &t) { - if (t.defined()) { - return ::tvm::relay::Resolve(this->unifier, t); - } else { - return IncompleteTypeNode::make(TypeParamNode::Kind::kType); + auto checked_body = with_frame([&]() { + local_stack.insert(let->var, unified_ty); + return Infer(let->body); + }); + + auto checked_let = LetNode::make(let->var, checked_value.expr, + checked_body.expr, let->value_type); + + return {checked_let, checked_body.type}; +} + +CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { + If ifn = GetRef(op); + + // Ensure the type of the guard is of Tensor[Bool, ()], + // that is a rank-0 boolean tensor. + auto checked_cond = this->Infer(ifn->cond); + auto cond_type = checked_cond.type; + + if (const TensorTypeNode *tt_node = cond_type.as()) { + TensorType tt = GetRef(tt_node); + if (tt->dtype.is_bool() && tt->shape.size() == 0) { + auto checked_true = this->Infer(ifn->true_value); + auto checked_false = this->Infer(ifn->false_value); + auto unified_type = + this->unify(checked_true.type, checked_false.type, ifn->span); + auto checked_if = IfNode::make(checked_cond.expr, checked_true.expr, + checked_false.expr); + return {checked_if, unified_type}; } } - Expr TypeInferencer::resolve(const Expr &e) { - CHECK(e.defined()); - return ::tvm::relay::Resolve(this->unifier, e); - } + this->fatal_error("if-then-else guard must be a rank-0 boolean tensor", + ifn->cond->span); +} - void TypeInferencer::CheckOp(Op op) { - throw Error("NYI"); - // if (!check_kind(op->type)) { - // report_error("the type of the operator is ill formed", op->type->span); - // } +CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) { + auto op = GetRef(op_node); + return {op, op->op_type}; +} - // // Fix me - // return op; +Type TypeInferencer::resolve(const Type &t) { + if (t.defined()) { + return ::tvm::relay::Resolve(this->unifier, t); + } else { + return IncompleteTypeNode::make(TypeParamNode::Kind::kType); } +} - // Defn TypeInferencer::CheckDefn(Defn defn) { - // // This is to handle recursion, but we need to speculatively - // // put it in env, then remove it. - // env->items.insert({defn->id, defn}); - - // Type expected_ty = this->resolve(defn->type); - - // Expr body = defn->body; - - // auto checked_ty = Check(body); - - // try { - // Type uret_type = unify(expected_ty, checked_ty, defn->body->span); - // CHECK(is_fully_resolved(uret_type)); - // // Now let's clean up our work from earlier. - // env->items.erase(defn->id); - // return DefnNode::make(defn->id, uret_type, this->resolve(defn->body)); - // } catch (const UnificationError& err) { - // std::string msg = std::string("mismatch between `") + - // PrintType(env, expected_ty, WrapWidth(40)) + "` and - // `" + PrintType(env, checked_ty, WrapWidth(40)) + - // "`"; - // fatal_error(msg, defn->span); - // } - // } +Expr TypeInferencer::resolve(const Expr &e) { + CHECK(e.defined()); + return ::tvm::relay::Resolve(this->unifier, e); +} - Expr Infer(const Environment &env, const Expr &e) { - TypeInferencer ti(env); - auto checked_expr = ti.Infer(e); - return checked_expr.expr; - } - - // Item Check(const Environment &env, const Item &i) { - // TypeInferencer tc(env); - - // try { - // if (const DefnNode *defn = i.as()) { - // return tc.CheckDefn(GetRef(defn)); - // } else if (const OpNode *op_node = i.as()) { - // return tc.CheckOp(GetRef(op_node)); - // } else { - // throw dmlc::Error("internal error: unknown Item type"); - // } - // } catch (const FatalTypeError &err) { - // env->display_errors(); - // throw dmlc::Error( - // "We encountered a fatal error while type checking your program, - // please " "read above for more details."); - // } +void TypeInferencer::CheckOp(Op op) { + throw Error("NYI"); + // if (!check_kind(op->type)) { + // report_error("the type of the operator is ill formed", op->type->span); // } - inline void TypeInferencer::report_error(const std::string &msg, Span sp) { - // this->env->report_error(msg, sp); - } - - void TypeInferencer::fatal_error(const std::string &msg, Span sp) { - // this->env->report_error(msg, sp); - throw FatalTypeError( - "internal error: this exception should" - "be handled and errors reported with Environment::display_errors\n" + - msg); + // // Fix me + // return op; +} + +// Defn TypeInferencer::CheckDefn(Defn defn) { +// // This is to handle recursion, but we need to speculatively +// // put it in env, then remove it. +// env->items.insert({defn->id, defn}); + +// Type expected_ty = this->resolve(defn->type); + +// Expr body = defn->body; + +// auto checked_ty = Check(body); + +// try { +// Type uret_type = unify(expected_ty, checked_ty, defn->body->span); +// CHECK(is_fully_resolved(uret_type)); +// // Now let's clean up our work from earlier. +// env->items.erase(defn->id); +// return DefnNode::make(defn->id, uret_type, this->resolve(defn->body)); +// } catch (const UnificationError& err) { +// std::string msg = std::string("mismatch between `") + +// PrintType(env, expected_ty, WrapWidth(40)) + "` and +// `" + PrintType(env, checked_ty, WrapWidth(40)) + +// "`"; +// fatal_error(msg, defn->span); +// } +// } + +Expr InferType(const Environment &env, const Expr &e) { + TypeInferencer ti(env); + auto checked_expr = ti.Infer(e); + return checked_expr.expr; +} + +// Item Check(const Environment &env, const Item &i) { +// TypeInferencer tc(env); + +// try { +// if (const DefnNode *defn = i.as()) { +// return tc.CheckDefn(GetRef(defn)); +// } else if (const OpNode *op_node = i.as()) { +// return tc.CheckOp(GetRef(op_node)); +// } else { +// throw dmlc::Error("internal error: unknown Item type"); +// } +// } catch (const FatalTypeError &err) { +// env->display_errors(); +// throw dmlc::Error( +// "We encountered a fatal error while type checking your program, +// please " "read above for more details."); +// } +// } + +inline void TypeInferencer::report_error(const std::string &msg, Span sp) { + // this->env->report_error(msg, sp); +} + +void TypeInferencer::fatal_error(const std::string &msg, Span sp) { + // this->env->report_error(msg, sp); + throw FatalTypeError( + "internal error: this exception should" + "be handled and errors reported with Environment::display_errors\n" + + msg); +} + +Type TypeInferencer::unify(const Type &t1, const Type &t2, Span sp) { + try { + return Normalize(this->unifier->unify(t1, t2)); + } catch (const dmlc::Error &e) { + std::stringstream ss; + ss << "Error unifying `"; + ss << t1; + // ss << PrintType(env, t1, WrapWidth(40)); + ss << "` and `"; + ss << t2; + // ss << PrintType(env, t2, WrapWidth(40)); + ss << "`: " << e.what(); + this->fatal_error(ss.str(), sp); } +} - Type TypeInferencer::unify(const Type &t1, const Type &t2, Span sp) { - try { - return Normalize(this->unifier->unify(t1, t2)); - } catch (const dmlc::Error &e) { - std::stringstream ss; - ss << "Error unifying `"; - ss << t1; - // ss << PrintType(env, t1, WrapWidth(40)); - ss << "` and `"; - ss << t2; - // ss << PrintType(env, t2, WrapWidth(40)); - ss << "`: " << e.what(); - this->fatal_error(ss.str(), sp); - } - } +TVM_REGISTER_API("relay._ir_pass.check_expr") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + Expr e = args[1]; + *ret = InferType(env, e); + }); - TVM_REGISTER_API("relay._ir_pass.check_expr") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - Expr e = args[1]; - *ret = Infer(env, e); - }); - - // TODO(@jroesch): put in a better namespace. - TVM_REGISTER_API("relay._ir_pass._get_checked_type") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Expr e = args[0]; - *ret = e->checked_type(); - }); - - IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { - std::shared_ptr n = - std::make_shared(); - n->kind = std::move(kind); - return IncompleteType(n); - } +// TODO(@jroesch): put in a better namespace. +TVM_REGISTER_API("relay._ir_pass._get_checked_type") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Expr e = args[0]; + *ret = e->checked_type(); + }); - TVM_REGISTER_API("relay._make.IncompleteType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - int kind = args[0]; - *ret = IncompleteTypeNode::make(static_cast(kind)); - }); +IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { + std::shared_ptr n = + std::make_shared(); + n->kind = std::move(kind); + return IncompleteType(n); +} + +TVM_REGISTER_API("relay._make.IncompleteType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + int kind = args[0]; + *ret = IncompleteTypeNode::make(static_cast(kind)); + }); - TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const IncompleteTypeNode *node, - tvm::IRPrinter *p) { - p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; - }); +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const IncompleteTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; + }); } // namespace relay -} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index e928cd5cb76a..bf95992d952f 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -6,12 +6,14 @@ from tvm.relay.env import Environment from tvm.relay.op import log, add -def has_type(expr, typ): - env = Environment({}) +def has_type(expr, typ, env=Environment({})): checked_expr = check_expr(env, expr) - import pdb; pdb.set_trace() return checked_expr.checked_type() == typ +def decl_has_type(env, name, typ): + func = env.lookup(name) + return func.checked_type() == typ + def test_monomorphic_let(): "Program: let x = 1; return x" b = IRBuilder() @@ -46,7 +48,24 @@ def test_dual_op(): b.ret(t2) assert has_type(func.to_func(), func_type([float_type()], float_type())) + +def test_decl(): + """Program: + def f(x : Tensor[f32, (10, 10)]) { + let lx = log(x); + return lx; + } + """ + b = IRBuilder() + x = b.param('x') + with b.decl('f', x) as d: + lx = d.let('lx', log(x)) + d.ret(lx) + _, env = b.get() + assert decl_has_type(env, 'f', func_type([float_type()], float_type())) + if __name__ == "__main__": test_monomorphic_let() test_single_op() test_dual_op() + test_decl() From df5b4c6998b96bccc6edfdc12228f2491ee21d3c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 31 Aug 2018 17:01:11 -0700 Subject: [PATCH 64/88] Address first round of CR comments --- include/tvm/relay/base.h | 2 +- include/tvm/relay/environment.h | 12 +++--- include/tvm/relay/error.h | 4 ++ include/tvm/relay/expr.h | 2 - include/tvm/relay/pass/alpha_eq.h | 42 +++++++++++++++++-- include/tvm/relay/pass/type_infer.h | 6 +-- include/tvm/relay/source_map.h | 2 +- src/relay/pass/unifier.h | 5 ++- src/relay/source_map.cc | 4 +- .../relay/test_tyck_eval_integration.py | 2 +- 10 files changed, 59 insertions(+), 22 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 092f5ceb8fc3..e78c4b28e9ca 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -24,7 +24,7 @@ namespace relay { /*! * \brief we always used NodeRef for referencing nodes. * - * By default, NodePtr is a std::shared_ptr of node + * By default, NodeRef is a std::shared_ptr of node */ using NodeRef = tvm::NodeRef; diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index 5ad7ba8e0010..ca5b8ac90df4 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -7,13 +7,13 @@ #ifndef TVM_RELAY_ENVIRONMENT_H_ #define TVM_RELAY_ENVIRONMENT_H_ +#include +#include +#include +#include +#include #include #include -#include "./expr.h" -#include "./type.h" -#include "./op.h" -#include "./error.h" -#include "tvm/relay/source_map.h" namespace tvm { namespace relay { @@ -28,7 +28,7 @@ struct Environment; * It contains all global functions, and configuration * options. * - * Many operations require acess to the global + * Many operations require access to the global * Environment. We pass the Environment by value * in a functional style as an explicit argument, * but we will mutate the Environment while optimizing diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 4f6a27d209c8..433c08abfd58 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -16,6 +16,10 @@ struct Error : dmlc::Error { Error(std::string msg) : dmlc::Error(msg) {} }; +struct InternalError : Error { + InternalError(std::string msg) : Error(msg) {} +}; + struct SpannedError { std::string msg; Span sp; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index ddac633f9d09..8ea3980dad46 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -224,8 +224,6 @@ class FunctionNode : public ExprNode { RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); -using Attrs = tvm::Attrs; - /*! * \brief Call corresponds to operator invocation. * Corresponds to the operator in computational graph terminology. diff --git a/include/tvm/relay/pass/alpha_eq.h b/include/tvm/relay/pass/alpha_eq.h index 9f3c2138a440..51b5b4dd8b70 100644 --- a/include/tvm/relay/pass/alpha_eq.h +++ b/include/tvm/relay/pass/alpha_eq.h @@ -6,14 +6,48 @@ #ifndef TVM_RELAY_ALPHA_EQ_H_ #define TVM_RELAY_ALPHA_EQ_H_ -#include "../type.h" -#include "../expr.h" +#include +#include namespace tvm { namespace relay { -bool AlphaEqual(const Expr & e1, const Expr & e2); -bool AlphaEqual(const Type & t1, const Type & t2); +/*! \brief Compare two expressions for structural equivalence. + + This comparsion operator respects scoping and compares + expressions without regard to variable choice. + + For example: `let x = 1 in x` is equal to `let y = 1 in y`. + + See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence + for more details. + + \param e1 The left hand expression. + \param e2 The right hand expression. + + \return true if equal, otherwise false + +*/ +bool AlphaEqual(const Expr& e1, const Expr& e2); + +/*! \brief Compare two types for structural equivalence. + + This comparsion operator respects scoping and compares + expressions without regard to variable choice. + + For example: `forall s, Tensor[f32, s]` is equal to + `forall w, Tensor[f32, w]`. + + See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence + for more details. + + \param t1 The left hand type. + \param t2 The right hand type. + + \return true if equal, otherwise false + +*/ +bool AlphaEqual(const Type& t1, const Type& t2); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/pass/type_infer.h b/include/tvm/relay/pass/type_infer.h index 2b860a5e89ef..a75eac6cc0da 100644 --- a/include/tvm/relay/pass/type_infer.h +++ b/include/tvm/relay/pass/type_infer.h @@ -6,8 +6,8 @@ * The pass produces a new expression with its checked_type * field populated and incomplete types resolved. */ -#ifndef TVM_RELAY_PASS_TYPECHECKER_H_ -#define TVM_RELAY_PASS_TYPECHECKER_H_ +#ifndef TVM_RELAY_PASS_TYPE_INFER_H_ +#define TVM_RELAY_PASS_TYPE_INFER_H_ #include "tvm/relay/expr.h" #include "tvm/relay/environment.h" @@ -22,4 +22,4 @@ Op CheckOp(const Environment & env, const Op & op); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_PASS_TYPECHECKER_H_ +#endif // TVM_RELAY_PASS_TYPE_INFER_H_ diff --git a/include/tvm/relay/source_map.h b/include/tvm/relay/source_map.h index 71bf93aa1ed9..a4dbc20b30ff 100644 --- a/include/tvm/relay/source_map.h +++ b/include/tvm/relay/source_map.h @@ -18,7 +18,7 @@ struct SourceFragment { std::string file_name; std::vector source_lines; - SourceFragment(std::string file_name, std::string source); + SourceFragment(const std::string& file_name, const std::string& source); SourceFragment(const SourceFragment& sf) { this->file_name = sf.file_name; diff --git a/src/relay/pass/unifier.h b/src/relay/pass/unifier.h index 5a4adea5c44e..64485768c2f0 100644 --- a/src/relay/pass/unifier.h +++ b/src/relay/pass/unifier.h @@ -61,8 +61,9 @@ class UnionFind : public NodeRef { UnionFind() {} explicit UnionFind(std::shared_ptr p) : NodeRef(p) {} - // no const so that union find can be mutable as a member of unifier - inline UnionFindNode* operator->() const { + // The union find structure is mutable so we do not use the standard macros + // and expose the pointer via `->`. + UnionFindNode* operator->() const { return static_cast(node_.get()); } diff --git a/src/relay/source_map.cc b/src/relay/source_map.cc index a1b3627bccc8..d784c7946954 100644 --- a/src/relay/source_map.cc +++ b/src/relay/source_map.cc @@ -14,7 +14,7 @@ namespace relay { using tvm::IRPrinter; using namespace tvm::runtime; -SourceFragment::SourceFragment(std::string file_name, std::string source) +SourceFragment::SourceFragment(const std::string& file_name, const std::string& source) : file_name(file_name), source_lines({}) { RELAY_LOG(INFO)<< "SourceFragment::SourceFragment source=" << source << std::endl; std::stringstream source_stream; @@ -28,7 +28,7 @@ SourceFragment::SourceFragment(std::string file_name, std::string source) } } -std::string SourceFragment::SourceAt(Span sp, int max_lines) { +std::string SourceFragment::SourceAt(Span sp, int max_lines = 1) { std::stringstream out; // We need to move from 1 based indexing to zero based indexing. diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index bf95992d952f..7d42448a175b 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -20,7 +20,7 @@ def test_monomorphic_let(): x = b.let('x', 1.0, value_type=float_type(64)) b.ret(x) - prog = b.get() + prog, _ = b.get() assert has_type(prog, float_type(64)) def test_single_op(): From b05842cbd325f9bd3deb618a6061498b79ba3f91 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 31 Aug 2018 17:01:45 -0700 Subject: [PATCH 65/88] Add skeleton for kind checker --- include/tvm/relay/pass.h | 20 +++++++++++++++-- src/relay/pass/kind_check.cc | 42 ++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 src/relay/pass/kind_check.cc diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 89f3dd48fc31..0d73ea2ce976 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -6,8 +6,8 @@ #ifndef TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_ -#include "tvm/relay/expr.h" -#include "tvm/relay/environment.h" +#include +#include namespace tvm { namespace relay { @@ -18,6 +18,22 @@ namespace relay { */ Expr InferType(const Environment & env, const Expr & e); +/*! + * \brief Check that types are well formed by applying "kinding rules". + * + * This pass ensures we do not do things that violate the design of the + * type system when writing down types. + * + * For example tensors are not allowed to contain functions in Relay. + * + * We check this by ensuring the `dtype` field of a Tensor always contains + * a data type such as `int`, `float`, `uint`. + * + * \param env The global environment. + * \param t The type to check. + */ +void KindCheck(const Environment& env, const Type& t); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_TYPECHECKER_H_ \ No newline at end of file diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc new file mode 100644 index 000000000000..65b0b087131c --- /dev/null +++ b/src/relay/pass/kind_check.cc @@ -0,0 +1,42 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file kindchecker.cc + * + * \brief Check that types are well formed by applying "kinding rules". + * + * This pass ensures we do not do things that violate the design of the + * type system when writing down types. + * + * For example tensors are not allowed to contain functions in Relay. + * + * We check this by ensuring the `dtype` field of a Tensor always + * contains a data type such as `int`, `float`, `uint`. + */ +#include +#include +#include "./type_visitor.h" + +namespace tvm { +namespace relay { + +using namespace tvm::runtime; + +struct KindChecker : TypeVisitor<> { + bool valid; + + KindChecker() : valid(true) {} + + bool Check(const Type &t) { + this->VisitType(t); + return valid; + } +}; + +bool KindCheck(const Type &t) { + KindChecker kc; + return kc.Check(t); +} + +} // namespace relay +} // namespace tvm \ No newline at end of file From 907c7dfda8476227e55cc81262ba9003fc92ece5 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 31 Aug 2018 17:08:58 -0700 Subject: [PATCH 66/88] Tweak docs in pass.h --- include/tvm/relay/pass.h | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 0d73ea2ce976..6d9761daa925 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -6,17 +6,26 @@ #ifndef TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_ -#include #include +#include namespace tvm { namespace relay { -/*! The result of type checking an expression is a new expression - * with unambigous type information filled in, as well as it's - * checked type field populated with the result type. +/*! \brief Infer the type of an expression with the provided environment. + * + * The result of type checking is a new expression with unambigous + * type information filled in, as well as it's checked type field + * populated with the result type. + * + * \param env The environment used for global settings and referencing + * global functions. + * + * \param e The expression to type check. + * + * \return A type checked expression with its checked_type field populated. */ -Expr InferType(const Environment & env, const Expr & e); +Expr InferType(const Environment& env, const Expr& e); /*! * \brief Check that types are well formed by applying "kinding rules". @@ -28,7 +37,7 @@ Expr InferType(const Environment & env, const Expr & e); * * We check this by ensuring the `dtype` field of a Tensor always contains * a data type such as `int`, `float`, `uint`. - * + * * \param env The global environment. * \param t The type to check. */ From e06c71f9e0ee33a68ca1596a9553f100b03b7770 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 31 Aug 2018 17:17:39 -0700 Subject: [PATCH 67/88] Refactor kind_check.{h, cc} --- include/tvm/relay/pass.h | 3 ++- include/tvm/relay/pass/type_infer.h | 25 ------------------------- src/relay/ir/environment.cc | 1 - src/relay/pass/kind_check.cc | 2 +- src/relay/pass/type_infer.cc | 9 ++------- 5 files changed, 5 insertions(+), 35 deletions(-) delete mode 100644 include/tvm/relay/pass/type_infer.h diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 6d9761daa925..738c6033147c 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -40,8 +40,9 @@ Expr InferType(const Environment& env, const Expr& e); * * \param env The global environment. * \param t The type to check. + * \return true if the rules are satisified otherwise false */ -void KindCheck(const Environment& env, const Type& t); +bool KindCheck(const Environment& env, const Type& t); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/pass/type_infer.h b/include/tvm/relay/pass/type_infer.h deleted file mode 100644 index a75eac6cc0da..000000000000 --- a/include/tvm/relay/pass/type_infer.h +++ /dev/null @@ -1,25 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file tvm/relay/pass/type_infer.h - * \brief Perform type inference and checking on Relay programs. - * - * The pass produces a new expression with its checked_type - * field populated and incomplete types resolved. - */ -#ifndef TVM_RELAY_PASS_TYPE_INFER_H_ -#define TVM_RELAY_PASS_TYPE_INFER_H_ - -#include "tvm/relay/expr.h" -#include "tvm/relay/environment.h" - -namespace tvm { -namespace relay { - -/*! \brief Ensures that an operator is well-formed with respect - * to Relay's type system. - */ -Op CheckOp(const Environment & env, const Op & op); - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_PASS_TYPE_INFER_H_ diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index 7861fb58820b..4c17f7cdbc89 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -7,7 +7,6 @@ #include #include #include -#include #include "./../pass/resolve.h" // #include "tvm/relay/util/rang.h" diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 65b0b087131c..c3823c8c3a35 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -33,7 +33,7 @@ struct KindChecker : TypeVisitor<> { } }; -bool KindCheck(const Type &t) { +bool KindCheck(const Environment& env, const Type &t) { KindChecker kc; return kc.Check(t); } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 514df129503a..894139c10b53 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -23,13 +23,12 @@ #include #include #include -#include +#include #include "./incomplete_type.h" #include "./resolve.h" #include "./type_subst.h" #include "./type_visitor.h" #include "./unifier.h" -// #include "tvm/relay/typeck/kindchecker.h" namespace tvm { namespace relay { @@ -378,11 +377,7 @@ FuncType TypeInferencer::instantiate(FuncType fn_ty, Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); inst_ty = TypeSubst(fn_ty, subst_map); - // if (!check_kind(t)) { - // this->fatal_error("Kind rules broken when instantiating type - // variables", - // t->span); - // } + CHECK(KindCheck(this->env, inst_ty)); return GetRef(inst_ty.as()); } From f813d51bdda644ee1b3e597831eedd48e2e88eea Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 2 Sep 2018 18:48:08 -0700 Subject: [PATCH 68/88] Improve type checking, can check control-flow-y program. --- include/tvm/relay/environment.h | 2 +- include/tvm/relay/expr.h | 2 + include/tvm/relay/pass.h | 1 + python/tvm/relay/env.py | 3 + python/tvm/relay/expr.py | 6 +- python/tvm/relay/ir_builder.py | 106 ++++++++----- python/tvm/relay/op/tensor.py | 20 +++ src/relay/ir/environment.cc | 11 +- src/relay/ir/expr.cc | 9 ++ src/relay/op/tensor/elemwise.cc | 35 +++++ src/relay/op/type_relations.cc | 143 ++++++++++-------- src/relay/op/type_relations.h | 1 + src/relay/pass/type_infer.cc | 72 +++------ src/relay/pass/unifier.cc | 9 +- .../relay/test_tyck_eval_integration.py | 38 ++++- 15 files changed, 299 insertions(+), 159 deletions(-) diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index ca5b8ac90df4..da782900fac5 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -78,7 +78,7 @@ class EnvironmentNode : public RelayNode { /*! \brief Add a source fragment to the environment. */ SourceName AddSource(std::string file_name, std::string source); - void ReportError(std::string msg, Span sp); + void AddDiagnostic(SpannedError); void DisplayErrors(); static constexpr const char* _type_key = "relay.Environment"; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 8ea3980dad46..7fd81ee0481b 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -215,6 +215,8 @@ class FunctionNode : public ExprNode { v->Visit("span", &span); } + Type fn_type() const; + TVM_DLL static Function make(tvm::Array params, Type ret_type, Expr body, tvm::Array ty_params); diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 738c6033147c..f92596c41179 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -26,6 +26,7 @@ namespace relay { * \return A type checked expression with its checked_type field populated. */ Expr InferType(const Environment& env, const Expr& e); +Expr InferType(const Environment& env, const GlobalVar & v, const Function & e); /*! * \brief Check that types are well formed by applying "kinding rules". diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 186ee8854c35..ee64ef6ce814 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -23,6 +23,9 @@ def add(self, var, func) -> None: def merge(self, other): return _env.Environment_Merge(self, other) + def global_var(self, var): + return _env.Environment_GetGlobalVar(self, var) + def lookup(self, var): if isinstance(var, str): return _env.Environment_Lookup_str(self, var) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 4f558210fb11..ec0cfd55ad62 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -9,7 +9,11 @@ from ._ir_pass import _get_checked_type from . import _make -class Expr(NodeBase): +class ExprBuilder(): + def __call__(self, *args, **kwargs): + return Call(self, args, None, None) + +class Expr(NodeBase, ExprBuilder): """The base type for all Relay exprressions.""" def checked_type(self): return _get_checked_type(self) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 50ebeb1bb12d..563c512639bc 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -2,27 +2,20 @@ import numpy as np import tvm from .type import FuncType, TensorType -from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function +from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function, If from .env import Environment from . import op as _op -class ExprBuilder(): - def __init__(self, expr): - self.expr = expr - - def __call__(self, *args): - return ExprBuilder(Call(self.expr, list(args), None, None)) - def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: """Convert Python values into the appropriate types for the Relay evaluator. """ if isinstance(arg, int): - return tvm.nd.array(arg, ctxt) + return tvm.nd.array(np.array(arg, dtype='int32'), ctxt) elif isinstance(arg, float): return tvm.nd.array(arg, ctxt) elif isinstance(arg, bool): - return tvm.nd.array(arg, ctxt) + return tvm.nd.array(np.array(arg, dtype='float32'), ctxt) elif isinstance(arg, np.ndarray): return tvm.nd.array(arg, ctxt) elif isinstance(arg, tvm.ndarray.NDArray): @@ -36,10 +29,10 @@ def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> Expr: raise Exception("..") else: value = convert(arg, ctxt) - return ExprBuilder(Constant(value)) + return Constant(value) class WithScope(object): - """Auxiliary scope with""" + """A wrapper for builder methods which introduce scoping.""" def __init__(self, enter_value, exit_cb): self._enter_value = enter_value @@ -49,7 +42,10 @@ def __enter__(self): return self._enter_value def __exit__(self, ptype, value, trace): - self._exit_cb() + if value: + raise value + else: + self._exit_cb() class PartialFunc(): @@ -77,15 +73,28 @@ def _mk_let(bindings, ret_value): return let_expr - class IRBuilder(): def __init__(self): self.bindings = [{}] self.scopes = [{}] self.params = [] - self.ret_value = None + self.ret_values = [None] self.env = Environment({}) + def enter_scope(self, params=[]): + self.bindings.append({}) + self.scopes.append({}) + self.params.append(params) + self.ret_values.append(None) + + + def exit_scope(self): + bindings = self.bindings.pop() + scopes = self.scopes.pop() + params = self.params.pop() + ret_value = self.ret_values.pop() + return bindings, scopes, params, ret_value + def bind(self, name, type, value): lv = LocalVar(name) @@ -98,12 +107,9 @@ def let(self, name, value, value_type=None): if isinstance(value, Param): value = value.var - if not (isinstance(value, Expr) or isinstance(value, ExprBuilder)): + if not isinstance(value, Expr): value = into_ast(value) - if isinstance(value, ExprBuilder): - value = value.expr - return self.bind(name, value_type, value) def function(self, *params): @@ -115,27 +121,52 @@ def function(self, *params): # self.params.append(relay_params) + self.enter_scope() + pfunc = PartialFunc(relay_params, None, None, []) def _on_exit(): - bindings = self.bindings.pop() - scope = self.scopes.pop() - ret_value = self.ret_value + bindings, scope, params, ret_value = self.exit_scope() body = _mk_let(bindings, ret_value) - self.ret_value = None pfunc.body = body - return WithScope(pfunc, _on_exit) def ret(self, x): - if not self.ret_value: - self.ret_value = x + if not self.ret_values[-1]: + self.ret_values[-1] = x else: raise Exception( "return value already set, a function can only have one return value") + def if_scope(self, cond): + self.enter_scope() + + def _on_exit(): + bindings, _, _, ret_value = self.exit_scope() + assert self.ret_values[-1] is None + true_branch = _mk_let(bindings, ret_value) + self.ret_values[-1] = If(cond, true_branch, None) + + return WithScope(10, _on_exit) + + + def else_scope(self): + self.enter_scope() + + def _on_exit(): + bindings, _, _, ret_value = self.exit_scope() + partial_if = self.ret_values[-1] + assert isinstance(partial_if, If) and partial_if.false_value is None + false_branch = _mk_let(bindings, ret_value) + self.ret_values[-1] = If( + partial_if.cond, + partial_if.true_value, + false_branch) + + return WithScope(10, _on_exit) + def param(self, name, ty=None): if not ty: ty = float_type() @@ -148,18 +179,21 @@ def param(self, name, ty=None): # arg = args[i] # if isinstance(arg, str): + def global_var(self, name: str): + return self.env.global_var(name) - def decl(self, name: str, *params): - decl_builder = IRBuilder() + def decl(self, name: str, *params, ret_type=None): + self.enter_scope() def _on_exit(): - exp, sub_env = decl_builder.get() - self.env.add(name, Function(params, None, exp)) - self.env.merge(sub_env) - - return WithScope(decl_builder, _on_exit) + bindings, _, _, ret_value = self.exit_scope() + exp = _mk_let(bindings, ret_value) + self.env.add(name, Function(params, ret_type, exp)) + return WithScope(10, _on_exit) + + # def while_loop(cond) def get(self): """Get the full program""" bindings = self.bindings.pop() @@ -171,16 +205,16 @@ def get(self): if self.scopes: raise Exception("IRBuilder: scoping error") - if bindings and scope and not self.ret_value: + if bindings and scope and not self.ret_values: raise Exception("IRBuilder: no return value set") - return _mk_let(bindings, self.ret_value), self.env + return _mk_let(bindings, self.ret_values[-1]), self.env def bool_dtype(): return 'uint1' def int_dtype(bits=32): - return f'int1{bits}' + return f'int{bits}' def float_dtype(bits=32): return f'float{bits}' diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index aa9ce6bf42e9..d0c1b88eb240 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -75,3 +75,23 @@ def add(lhs, rhs): The computed result. """ return _make.add(lhs, rhs) + +def subtract(lhs, rhs): + """Take sqrt of data. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.add(lhs, rhs) + +def equal(lhs, rhs): + return _make.equal(lhs, rhs) \ No newline at end of file diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index 4c17f7cdbc89..a1a754615350 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -42,7 +42,8 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { void EnvironmentNode::Add(const GlobalVar& var, const Function & func, bool update) { // Type check the item before we add it to the environment. auto env = GetRef(this); - Expr checked_expr = InferType(env, func); + + Expr checked_expr = InferType(env, var, func); if (const FunctionNode *func_node = checked_expr.as()) { auto checked_func = GetRef(func_node); @@ -104,13 +105,11 @@ void EnvironmentNode::Merge(const Environment & env) { inline SourceName EnvironmentNode::AddSource(std::string file_name, std::string source) { - throw Error("need to restore error handling"); - // return this->source_map_.add_source(file_name, source); + return this->source_map_.AddSource(file_name, source); } -void EnvironmentNode::ReportError(std::string msg, Span sp) { - throw Error("need to restore error handling"); - // this->errors_.push_back(Error(msg, sp)); +void EnvironmentNode::AddDiagnostic(SpannedError error) { + this->errors_.push_back(error); } void EnvironmentNode::DisplayErrors() { diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 47d253e91c21..8dce7c054c8e 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -119,6 +119,15 @@ Function FunctionNode::make(tvm::Array params, Type ret_type, Expr body, return Function(n); } +Type FunctionNode::fn_type() const { + Array param_types; + for (auto param : this->params) { + param_types.push_back(param->type); + } + + return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {}); +} + TVM_REGISTER_API("relay._make.Function") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = FunctionNode::make(args[0], args[1], args[2], args[3]); diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index cd90705c6476..76adfbbfb968 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -85,5 +85,40 @@ RELAY_REGISTER_OP("add") // input2: Tensor[dtype, s2] // output: Tensor[dtype, broadcast(s1, s2)] +// Addition +TVM_REGISTER_API("relay.op._make.subtract") + .set_body_typed([](Expr lhs, Expr rhs) { + static const Op& op = Op::Get("subtract"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + }); + +RELAY_REGISTER_OP("subtract") + .set_num_inputs(2) + .add_argument("lhs", "Tensor", "The left hand side tensor.") + .add_argument("rhs", "Tensor", "The right hand side tensor.") + .set_support_level(1) + .add_type_func("BroadcastComp", BroadcastCompRel); + + // def broadcast(s1, s2): + // ... + // + // input1: Tensor[dtype, s1] + // input2: Tensor[dtype, s2] + // output: Tensor[dtype, broadcast(s1, s2)] + +// Addition +TVM_REGISTER_API("relay.op._make.equal") + .set_body_typed([](Expr lhs, Expr rhs) { + static const Op& op = Op::Get("equal"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + }); + +RELAY_REGISTER_OP("equal") + .set_num_inputs(2) + .add_argument("lhs", "Tensor", "The left hand side tensor.") + .add_argument("rhs", "Tensor", "The right hand side tensor.") + .set_support_level(1) + .add_type_func("BroadcastComp", BroadcastCompRel); + } // namespace relayv } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index d97b8f96e85c..32d81a1d445e 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -4,15 +4,15 @@ * \brief A set of utilities and common functionality * for type relations. */ -#include #include +#include #include #include "../pass/incomplete_type.h" namespace tvm { namespace relay { -TensorType as_ttype(const Type & t) { +TensorType as_ttype(const Type& t) { if (auto tt_node = t.as()) { return GetRef(tt_node); } else { @@ -21,94 +21,115 @@ TensorType as_ttype(const Type & t) { } // TODO(@jroesch) what size value do we extract? -int to_int(const tvm::Expr & e) { +int to_int(const tvm::Expr& e) { auto imm = e.as(); CHECK(imm); std::cout << "TYPE: " << imm << imm->type << std::endl; return imm->value; } -Array IdentityRel(const Array & types, int num_args) { - CHECK(types.size() == 2); - auto t1 = as_ttype(types[0]); - if (t1 && types[1].as()) { - return {t1, t1}; - } else { - return types; - } +Array IdentityRel(const Array& types, int num_args) { + CHECK(types.size() == 2); + auto t1 = as_ttype(types[0]); + if (t1 && types[1].as()) { + return {t1, t1}; + } else { + return types; + } } -static Type ConcreteBroadcast(const TensorType & t1, const TensorType & t2) { - RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2 << std::endl; +static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, + DataType output_dtype) { + RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2 + << std::endl; auto sh1 = t1->shape; auto sh2 = t2->shape; - RELAY_LOG(INFO) << "ConcreteBroadcast: sh1=" << sh1 << " sh2=" << sh2 << std::endl; - CHECK(sh1.size() > 0); - CHECK(sh2.size() > 0); - - auto suffix_len = static_cast(std::min(sh1.size(), sh2.size())); - auto full_len = static_cast(std::max(sh1.size(), sh2.size())); - - std::cout << "Length" << suffix_len << full_len << (full_len - suffix_len - 1) << std::endl; - auto lower_bound = full_len - suffix_len - 1; - - for (int64_t i = full_len - 1; i > lower_bound; i--) { - std::cout << "Index i=" << i << std::endl; - auto dim1 = to_int(sh1[i]); - auto dim2 = to_int(sh2[i]); - if (dim1 != dim2) { - CHECK(false); + RELAY_LOG(INFO) << "ConcreteBroadcast: sh1=" << sh1 << " sh2=" << sh2 + << std::endl; + if (sh1.size() == 0 && sh2.size() == 0) { + return TensorTypeNode::make({}, output_dtype); + // We have non-zero shapes so broadcast rules apply. + } else { + auto suffix_len = static_cast(std::min(sh1.size(), sh2.size())); + auto full_len = static_cast(std::max(sh1.size(), sh2.size())); + + std::cout << "Length" << suffix_len << full_len + << (full_len - suffix_len - 1) << std::endl; + auto lower_bound = full_len - suffix_len - 1; + + for (int64_t i = full_len - 1; i > lower_bound; i--) { + std::cout << "Index i=" << i << std::endl; + auto dim1 = to_int(sh1[i]); + auto dim2 = to_int(sh2[i]); + if (dim1 != dim2) { + CHECK(false); + } } - } - Array larger; - Array smaller; + Array larger; + Array smaller; - for (int i = 0; i < (full_len - suffix_len); i++) { - smaller.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), 1)); - } + for (int i = 0; i < (full_len - suffix_len); i++) { + smaller.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), 1)); + } - if (sh1.size() < sh2.size()) { - for (auto sh : sh1) { - smaller.push_back(sh); + if (sh1.size() < sh2.size()) { + for (auto sh : sh1) { + smaller.push_back(sh); + } + larger = sh2; + } else if (sh1.size() > sh2.size()) { + for (auto sh : sh1) { + larger.push_back(sh); + } + smaller = sh2; + } else { + larger = sh1; + smaller = sh2; } - larger = sh2; - } else if (sh1.size() > sh2.size()) { - for (auto sh : sh1) { - larger.push_back(sh); + + CHECK(larger.size() == smaller.size()); + + Array out_shape; + for (int i = 0; i < smaller.size(); i++) { + auto left = smaller[i].as(); + auto right = larger[i].as(); + CHECK(left); + CHECK(right); + int64_t dim = std::max(left->value, right->value); + out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim)); } - smaller = sh2; - } else { - larger = sh1; - smaller = sh2; - } - CHECK(larger.size() == smaller.size()); + return TensorTypeNode::make(out_shape, output_dtype); + } +} - Array out_shape; - for (int i = 0; i < smaller.size(); i++) { - auto left = smaller[i].as(); - auto right = larger[i].as(); - CHECK(left); - CHECK(right); - int64_t dim = std::max(left->value, right->value); - out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim)); +Array BroadcastRel(const Array& types, int num_args) { + CHECK(types.size() == 3); + if (auto t1 = as_ttype(types[0])) { + if (auto t2 = as_ttype(types[1])) { + std::cout << t1->dtype << t2->dtype << std::endl; + CHECK(t1->dtype == t2->dtype); + return {t1, t2, ConcreteBroadcast(t1, t2, t1->dtype)}; + } } - return TensorTypeNode::make(out_shape, t1->dtype); + return types; } -Array BroadcastRel(const Array & types, int num_args) { +/* A relation which specifies broadcasting rules for operations which + compute boolean results. +*/ +Array BroadcastCompRel(const Array& types, int num_args) { CHECK(types.size() == 3); if (auto t1 = as_ttype(types[0])) { if (auto t2 = as_ttype(types[1])) { - return { t1, t2, ConcreteBroadcast(t1, t2) }; + return {t1, t2, ConcreteBroadcast(t1, t2, HalideIR::Bool())}; } } return types; } - -} // namespace relayv +} // namespace relay } // namespace tvm diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h index f2c4876705b6..71c98fef7da1 100644 --- a/src/relay/op/type_relations.h +++ b/src/relay/op/type_relations.h @@ -15,6 +15,7 @@ namespace relay { Array IdentityRel(const Array & types, int num_args); Array BroadcastRel(const Array & types, int num_args); +Array BroadcastCompRel(const Array & types, int num_args); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 894139c10b53..1adfb95d1e15 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -188,7 +188,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const LocalVarNode *op) { CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode *op) { GlobalVar var = GetRef(op); Expr e = this->env->Lookup(var); - return { var, e->checked_type() }; + return {var, e->checked_type()}; } CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode *const_node) { @@ -206,7 +206,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { field_types.push_back(checked_field.type); } - return { TupleNode::make(field_exprs), TupleTypeNode::make(field_types) }; + return {TupleNode::make(field_exprs), TupleTypeNode::make(field_types)}; } CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { @@ -488,21 +488,14 @@ CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { auto checked_cond = this->Infer(ifn->cond); auto cond_type = checked_cond.type; - if (const TensorTypeNode *tt_node = cond_type.as()) { - TensorType tt = GetRef(tt_node); - if (tt->dtype.is_bool() && tt->shape.size() == 0) { - auto checked_true = this->Infer(ifn->true_value); - auto checked_false = this->Infer(ifn->false_value); - auto unified_type = - this->unify(checked_true.type, checked_false.type, ifn->span); - auto checked_if = IfNode::make(checked_cond.expr, checked_true.expr, - checked_false.expr); - return {checked_if, unified_type}; - } - } - - this->fatal_error("if-then-else guard must be a rank-0 boolean tensor", - ifn->cond->span); + this->unify(cond_type, TensorTypeNode::make({}, HalideIR::Bool()), ifn->cond->span); + auto checked_true = this->Infer(ifn->true_value); + auto checked_false = this->Infer(ifn->false_value); + auto unified_type = + this->unify(checked_true.type, checked_false.type, ifn->span); + auto checked_if = IfNode::make(checked_cond.expr, checked_true.expr, + checked_false.expr); + return {checked_if, unified_type}; } CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) { @@ -510,7 +503,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) { return {op, op->op_type}; } -Type TypeInferencer::resolve(const Type &t) { +Type TypeInferencer::resolve(const Type& t) { if (t.defined()) { return ::tvm::relay::Resolve(this->unifier, t); } else { @@ -518,21 +511,11 @@ Type TypeInferencer::resolve(const Type &t) { } } -Expr TypeInferencer::resolve(const Expr &e) { +Expr TypeInferencer::resolve(const Expr& e) { CHECK(e.defined()); return ::tvm::relay::Resolve(this->unifier, e); } -void TypeInferencer::CheckOp(Op op) { - throw Error("NYI"); - // if (!check_kind(op->type)) { - // report_error("the type of the operator is ill formed", op->type->span); - // } - - // // Fix me - // return op; -} - // Defn TypeInferencer::CheckDefn(Defn defn) { // // This is to handle recursion, but we need to speculatively // // put it in env, then remove it. @@ -565,31 +548,24 @@ Expr InferType(const Environment &env, const Expr &e) { return checked_expr.expr; } -// Item Check(const Environment &env, const Item &i) { -// TypeInferencer tc(env); +Expr InferType(const Environment &env, const GlobalVar & var, const Function & func) { + TypeInferencer ti(env); + auto func_copy = FunctionNode::make(func->params, func->ret_type, func->body, func->type_params); + func_copy->checked_type_ = ti.resolve(func_copy->fn_type()); + env->functions.Set(var, func_copy); + auto checked_expr = ti.Infer(func); + auto map_node = env->functions.CopyOnWrite(); + map_node->data.erase(var.node_); + return checked_expr.expr; +} -// try { -// if (const DefnNode *defn = i.as()) { -// return tc.CheckDefn(GetRef(defn)); -// } else if (const OpNode *op_node = i.as()) { -// return tc.CheckOp(GetRef(op_node)); -// } else { -// throw dmlc::Error("internal error: unknown Item type"); -// } -// } catch (const FatalTypeError &err) { -// env->display_errors(); -// throw dmlc::Error( -// "We encountered a fatal error while type checking your program, -// please " "read above for more details."); -// } -// } inline void TypeInferencer::report_error(const std::string &msg, Span sp) { - // this->env->report_error(msg, sp); + this->env->AddDiagnostic({msg, sp}); } void TypeInferencer::fatal_error(const std::string &msg, Span sp) { - // this->env->report_error(msg, sp); + this->env->AddDiagnostic({msg, sp}); throw FatalTypeError( "internal error: this exception should" "be handled and errors reported with Environment::display_errors\n" + diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 4d986ad79ab1..4558f6a24919 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -180,6 +180,12 @@ Type TypeUnifierNode::VisitType(const Type & t1, const Type t2) { // When the right hand size is a type variable immediately unify. if (const IncompleteTypeNode *tvn2 = t2.as()) { return this->unifyWithIncompleteType(t1, GetRef(tvn2)); + // The TypeCallNode case is special and not symmetric. + // + // We flip the arguments so we hit the TypeCall and other case in there is + // ever a type call. + } else if (const TypeCallNode *tvn2 = t2.as()) { + return TypeFunctor::VisitType(t2, t1); } else { return TypeFunctor::VisitType(t1, t2); } @@ -353,7 +359,8 @@ Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { return TypeCallNode::make(unified_func, new_args); } else { - throw UnificationError("Cannot unify call with non-call"); + auto args = ty_call1->args; + return this->VisitType(args[args.size() - 1], t2); } } diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 7d42448a175b..1ae78441e166 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -2,9 +2,10 @@ for expressions. """ from tvm.relay.ir_pass import check_expr -from tvm.relay.ir_builder import IRBuilder, float_type, func_type, tensor_type +from tvm.relay.ir_builder import IRBuilder, float_type, int_type +from tvm.relay.ir_builder import func_type, tensor_type, into_ast from tvm.relay.env import Environment -from tvm.relay.op import log, add +from tvm.relay.op import log, add, equal, subtract def has_type(expr, typ, env=Environment({})): checked_expr = check_expr(env, expr) @@ -40,6 +41,7 @@ def test_dual_op(): return t1; } """ + pass b = IRBuilder() with b.function(('x', tensor_type(10, 10))) as func: x, = func.param_ids() @@ -56,16 +58,42 @@ def f(x : Tensor[f32, (10, 10)]) { return lx; } """ + pass b = IRBuilder() x = b.param('x') - with b.decl('f', x) as d: - lx = d.let('lx', log(x)) - d.ret(lx) + with b.decl('f', x): + lx = b.let('lx', log(x)) + b.ret(lx) _, env = b.get() assert decl_has_type(env, 'f', func_type([float_type()], float_type())) +def test_recursion(): + """ + Program: + def f(n: i32, data: f32) -> f32 { + if (n == 0) { + return f(n - 1, log(data)); + } else { + return data; + } + } + f(2, 10000); + """ + b = IRBuilder() + f = b.global_var('f') + n = b.param('n', ty=int_type()) + data = b.param('data', ty=float_type()) + with b.decl(f, n, data): + with b.if_scope(equal(n, into_ast(0.0))): + b.ret(f(subtract(n, into_ast(1)), log(data))) + with b.else_scope(): + b.ret(data) + b.ret(f(into_ast(2.0), into_ast(10000.0))) + assert decl_has_type(b.env, 'f', func_type([int_type(), float_type()], float_type())) + if __name__ == "__main__": test_monomorphic_let() test_single_op() test_dual_op() test_decl() + test_recursion() From 2a2dc3f0d4238893beb620ff24f697c930f72ff9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 2 Sep 2018 18:52:14 -0700 Subject: [PATCH 69/88] Rename TVM compiler to to_tvm.py --- python/tvm/relay/ir_builder.py | 2 +- python/tvm/relay/{tvm_rts_backend.py => to_tvm.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename python/tvm/relay/{tvm_rts_backend.py => to_tvm.py} (100%) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 563c512639bc..a0a8c2e008da 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -220,7 +220,7 @@ def float_dtype(bits=32): return f'float{bits}' def uint_dtype(bits=32): - return f'fuint{bits}' + return f'uint{bits}' def int_type(bits=32, lanes=1): # TODO(@jroesch, @tqchen) How do we set lanes? diff --git a/python/tvm/relay/tvm_rts_backend.py b/python/tvm/relay/to_tvm.py similarity index 100% rename from python/tvm/relay/tvm_rts_backend.py rename to python/tvm/relay/to_tvm.py From 6b877cd8c964f6867585e047c9c36c95884c1b85 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 4 Sep 2018 13:49:44 -0700 Subject: [PATCH 70/88] Begin work on lowering Relay to TVM --- include/tvm/relay/base.h | 12 + include/tvm/relay/environment.h | 32 ++- include/tvm/relay/expr.h | 1 + include/tvm/relay/expr_visitor.h | 110 +++++---- include/tvm/relay/op.h | 3 +- python/tvm/relay/env.py | 3 + python/tvm/relay/ir_builder.py | 17 +- python/tvm/relay/ir_pass.py | 227 +++++++++++++++++- python/tvm/relay/op/__init__.py | 2 +- python/tvm/relay/op/op.py | 42 +++- python/tvm/relay/to_tvm.py | 49 ++-- src/relay/ir/environment.cc | 55 +++-- src/relay/ir/op.cc | 54 +++++ src/relay/pass/resolve.cc | 3 +- src/relay/pass/type_infer.cc | 7 +- .../relay/test_tyck_eval_integration.py | 53 +++- 16 files changed, 539 insertions(+), 131 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index e78c4b28e9ca..09f3a94e1edb 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -154,6 +154,18 @@ RefType GetRef(const NodeType* ptr) { return RefType(const_cast(ptr)->shared_from_this()); } +/*! + * \brief Get PackedFunction from global registry and + * report error if it does not exist + * \param name The name of the function. + * \return The created PackedFunc. + */ +inline const PackedFunc& GetPackedFunc(const std::string& name) { + const PackedFunc* pf = tvm::runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find function " << name << " in registry"; + return *pf; +} + } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index da782900fac5..29cde295398d 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -7,11 +7,11 @@ #ifndef TVM_RELAY_ENVIRONMENT_H_ #define TVM_RELAY_ENVIRONMENT_H_ +#include #include -#include #include -#include #include +#include #include #include @@ -24,8 +24,8 @@ struct Environment; * * The global environment contains the global * information needed to compile a Relay program. - * - * It contains all global functions, and configuration + * + * It contains all global functions, and configuration * options. * * Many operations require access to the global @@ -34,14 +34,15 @@ struct Environment; * but we will mutate the Environment while optimizing * Relay programs. * - * The functional style allows users to construct custom + * The functional style allows users to construct custom * environments easily, for example each thread can store * an Environment while auto-tuning. * */ class EnvironmentNode : public RelayNode { private: - /*! \brief A map from string names to global variables ensures global uniqueness. */ + /*! \brief A map from string names to global variables ensures global + * uniqueness. */ tvm::Map global_map_; /*! \brief A map from file names to source fragments. */ SourceMap source_map_; @@ -56,11 +57,10 @@ class EnvironmentNode : public RelayNode { void VisitAttrs(tvm::AttrVisitor* v) final {} - TVM_DLL static Environment make( - tvm::Map global_funcs); + TVM_DLL static Environment make(tvm::Map global_funcs); - void Add(const GlobalVar& var, const Function & func, bool update = false); - void Update(const GlobalVar& var, const Function & func); + void Add(const GlobalVar& var, const Function& func, bool update = false); + void Update(const GlobalVar& var, const Function& func); void Remove(const GlobalVar& var); /*! \brief Lookup a global function by its variable. */ @@ -70,14 +70,20 @@ class EnvironmentNode : public RelayNode { Function Lookup(const GlobalVar& id); /*! \brief Lookup a global function by its string name */ - Function Lookup(const std::string & s); - + Function Lookup(const std::string& s); + // TODO(@jroesch, @tqchen): what are the semantics here - void Merge(const Environment & env); + void Merge(const Environment& env); /*! \brief Add a source fragment to the environment. */ SourceName AddSource(std::string file_name, std::string source); + using Transformer = runtime::TypedPackedFunc< + runtime::TypedPackedFunc(const Environment&)>; + + /*! \brief Apply a function over every function in the global environment. */ + void Transform(Transformer tranformer); + void AddDiagnostic(SpannedError); void DisplayErrors(); diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 7fd81ee0481b..a882b7cc1ea7 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -271,6 +271,7 @@ class CallNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("op", &op); v->Visit("args", &args); + v->Visit("attrs", &attrs); v->Visit("type_args", &type_args); v->Visit("span", &span); } diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index e15f25a39eb3..6f2a7f98542a 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -15,95 +15,99 @@ namespace tvm { namespace relay { -template -class ExprVisitor : public ::tvm::relay::ExprFunctor { +class ExprVisitor : public ::tvm::relay::ExprFunctor { public: - void VisitExpr_(const LocalVarNode* op, Args... args) override { return; } + void VisitExpr_(const LocalVarNode* op) override { return; } - void VisitExpr_(const GlobalVarNode* op, Args... args) override { return; } + void VisitExpr_(const GlobalVarNode* op) override { return; } - void VisitExpr_(const ConstantNode* op, Args... args) override { return; } + void VisitExpr_(const ConstantNode* op) override { return; } - void VisitExpr_(const TupleNode* op, Args... args) override { + void VisitExpr_(const TupleNode* op) override { for (auto field : op->fields) { - this->VisitExpr(field, args...); + this->VisitExpr(field); } } - void VisitExpr_(const ParamNode* op, Args... args) override { - this->VisitExpr(op->var, args...); + void VisitExpr_(const ParamNode* op) override { + this->VisitExpr(op->var); } - void VisitExpr_(const FunctionNode* op, Args... args) override { + void VisitExpr_(const FunctionNode* op) override { for (auto param : op->params) { - this->VisitExpr(param, args...); + this->VisitExpr(param); } - this->VisitExpr(op->body, args...); + this->VisitExpr(op->body); } - void VisitExpr_(const CallNode* op, Args... args) override { - this->VisitExpr(op->op, args...); + void VisitExpr_(const CallNode* op) override { + this->VisitExpr(op->op); + for (auto ty_arg : op->type_args) { + this->VisitType(ty_arg); + } + for (auto arg : op->args) { - this->VisitExpr(arg, args...); + this->VisitExpr(arg); } } - void VisitExpr_(const LetNode* op, Args... args) override { - this->VisitExpr(op->var, args...); - this->VisitExpr(op->value, args...); - this->VisitExpr(op->body, args...); + void VisitExpr_(const LetNode* op) override { + this->VisitExpr(op->var); + this->VisitExpr(op->value); + this->VisitExpr(op->body); } - void VisitExpr_(const IfNode* op, Args... args) override { - this->VisitExpr(op->cond, args...); - this->VisitExpr(op->true_value, args...); - this->VisitExpr(op->false_value, args...); + void VisitExpr_(const IfNode* op) override { + this->VisitExpr(op->cond); + this->VisitExpr(op->true_value); + this->VisitExpr(op->false_value); } - void VisitExpr_(const OpNode* op, Args... args) override { return; } + void VisitExpr_(const OpNode* op) override { return; } + + virtual void VisitType(const Type& t) {} }; -template -class ExprFVisitor : public ::tvm::relay::ExprFunctor { +class ExprFVisitor : public ::tvm::relay::ExprFunctor { public: - Expr VisitExpr_(const LocalVarNode* op, Args... args) override { + Expr VisitExpr_(const LocalVarNode* op) override { return GetRef(op); } - Expr VisitExpr_(const GlobalVarNode* op, Args... args) override { + Expr VisitExpr_(const GlobalVarNode* op) override { return GetRef(op); } - Expr VisitExpr_(const OpNode* op, Args... args) override { + Expr VisitExpr_(const OpNode* op) override { return GetRef(op); } - Expr VisitExpr_(const TupleNode* op, Args... args) override { + Expr VisitExpr_(const TupleNode* op) override { tvm::Array fields; for (auto field : op->fields) { - fields.push_back(this->VisitExpr(field, args...)); + fields.push_back(this->VisitExpr(field)); } return TupleNode::make(fields); } - Expr VisitExpr_(const ParamNode* op, Args... args) override { - Expr var_expr = this->VisitExpr(op->var, args...); + Expr VisitExpr_(const ParamNode* op) override { + Expr var_expr = this->VisitExpr(op->var); if (const LocalVarNode* var_node = var_expr.as()) { auto var = GetRef(var_node); - auto type = this->VisitType(op->type, args...); + auto type = this->VisitType(op->type); return ParamNode::make(var, type); } else { throw dmlc::Error("the default param visitor has bug"); } } - Expr VisitExpr_(const FunctionNode* op, Args... args) override { + Expr VisitExpr_(const FunctionNode* op) override { tvm::Array ty_params; for (auto ty : op->type_params) { - Type ty_param_type = VisitType(ty, args...); + Type ty_param_type = VisitType(ty); if (auto ty_param = ty_param_type.as()) { auto ty_param_ref = GetRef(ty_param); ty_params.push_back(ty_param_ref); @@ -114,7 +118,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor params; for (auto param : op->params) { - Expr param_expr = this->VisitExpr(param, args...); + Expr param_expr = this->VisitExpr(param); if (const ParamNode* param_node = param_expr.as()) { auto param = GetRef(param_node); params.push_back(param); @@ -123,23 +127,23 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctorVisitType(op->ret_type, args...); - auto body = this->VisitExpr(op->body, args...); + auto ret_type = this->VisitType(op->ret_type); + auto body = this->VisitExpr(op->body); return FunctionNode::make(params, ret_type, body, ty_params); } - Expr VisitExpr_(const CallNode* call_node, Args... args) override { - auto fn = this->VisitExpr(call_node->op, args...); + Expr VisitExpr_(const CallNode* call_node) override { + auto fn = this->VisitExpr(call_node->op); tvm::Array ty_args; for (auto ty_arg : call_node->type_args) { - auto new_ty_arg = this->VisitType(ty_arg, args...); + auto new_ty_arg = this->VisitType(ty_arg); ty_args.push_back(new_ty_arg); } tvm::Array call_args; for (auto arg : call_node->args) { - call_args.push_back(this->VisitExpr(arg, args...)); + call_args.push_back(this->VisitExpr(arg)); } auto call = CallNode::make(fn, call_args, call_node->attrs, ty_args); @@ -147,27 +151,27 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctorVisitExpr(op->var, args...); + Expr VisitExpr_(const LetNode* op) override { + Expr var_expr = this->VisitExpr(op->var); if (const LocalVarNode* var_node = var_expr.as()) { auto var = GetRef(var_node); - auto type = this->VisitType(op->value_type, args...); - auto value = this->VisitExpr(op->value, args...); - auto body = this->VisitExpr(op->body, args...); + auto type = this->VisitType(op->value_type); + auto value = this->VisitExpr(op->value); + auto body = this->VisitExpr(op->body); return LetNode::make(var, value, body, type); } else { throw dmlc::Error("the default let visitor has error"); } } - Expr VisitExpr_(const IfNode* op, Args... args) override { - auto guard = this->VisitExpr(op->cond, args...); - auto true_b = this->VisitExpr(op->true_value, args...); - auto false_b = this->VisitExpr(op->false_value, args...); + Expr VisitExpr_(const IfNode* op) override { + auto guard = this->VisitExpr(op->cond); + auto true_b = this->VisitExpr(op->true_value); + auto false_b = this->VisitExpr(op->false_value); return IfNode::make(guard, true_b, false_b); } - virtual Type VisitType(const Type& t, Args... args) { return t; } + virtual Type VisitType(const Type& t) { return t; } }; } // namespace relay diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 0e5483174c53..2e8d090f6625 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -34,7 +34,7 @@ class OpNode : public relay::ExprNode { public: /*! \brief name of the operator */ std::string name; - + /*! \brief the type of the operator */ Type op_type; /*! * \brief detailed description of the operator @@ -62,6 +62,7 @@ class OpNode : public relay::ExprNode { void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("name", &name); + v->Visit("op_type", &op_type); v->Visit("description", &description); v->Visit("arguments", &arguments); v->Visit("attrs_type_key", &attrs_type_key); diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index ee64ef6ce814..86c9ac794b4e 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -31,3 +31,6 @@ def lookup(self, var): return _env.Environment_Lookup_str(self, var) else: return _env.Environment_Lookup(self, var) + + def transform(self, transformer): + _env.Environment_Transform(self, transformer) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index a0a8c2e008da..098eb474c6ee 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -25,8 +25,12 @@ def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: raise Exception(f"unsupported argument type {type(arg)}") def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> Expr: - if isinstance(arg, tuple): + if isinstance(arg, Expr): + return arg + elif isinstance(arg, tuple): raise Exception("..") + elif isinstance(arg, PartialFunc): + return arg.to_func() else: value = convert(arg, ctxt) return Constant(value) @@ -114,10 +118,11 @@ def let(self, name, value, value_type=None): def function(self, *params): relay_params = [] - for name, ty in params: - lv = LocalVar(name) - self.scopes[-1][name] = lv - relay_params.append(Param(lv, ty)) + for param in params: + name = param.var + ty = param.type + self.scopes[-1][name.name_hint] = name + relay_params.append(Param(name, ty)) # self.params.append(relay_params) @@ -135,7 +140,7 @@ def _on_exit(): def ret(self, x): if not self.ret_values[-1]: - self.ret_values[-1] = x + self.ret_values[-1] = into_ast(x) else: raise Exception( "return value already set, a function can only have one return value") diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index ad7a68eac392..70d5f09237d8 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -1,5 +1,230 @@ -#pylint: disable-all +# pylint: disable=no-else-return +# pylint: disable=unidiomatic-typecheck +""" +The optimizer for Relay. +Exposes an interface for configuring the optimizer and scripting +it directly in Python. +""" +from typing import TypeVar, Generic, Union +from typing import Dict, Tuple, List, Callable +import tvm + +from .expr import Expr +from .expr import Function, Let, Call, LocalVar +from .expr import GlobalVar, If, Constant +from .type import Type +from .env import Environment +from .op import Op +# import relay.make as relay_mk +# from relay import ir +# from relay.env import Environment +# from relay.tyck import check_expr +# from relay.first_order_reverse_ad import fo_with_gradient +# from relay.anf import to_anf from . import _ir_pass +# Expose checking expression, should rename to infer_type. check_expr = _ir_pass.check_expr + +# # pylint: disable=invalid-name +# concretize = _opt.concretize + +# # pylint: disable=invalid-name +# optimize = _opt.optimize + +# # pylint: disable=invalid-name +# type_specialize = _opt.type_specialize + +# # pylint: disable=invalid-name +# compile_ops_to_module = _opt.compile_ops_to_module + + +@tvm.register_func("relay.mangle") +def mangle(name: str, types: List[Type]) -> str: + for typ in types: + name += str(typ) + "_" + return name + +T = TypeVar('T') +class AbstractExprVisitor(Generic[T]): + """A functional visitor over Expr in Python.""" + + # pylint: disable=no-else-return + def visit(self, expr: Expr) -> T: + """Apply the visitor to an expression.""" + if isinstance(expr, Function): + return self.visit_function(expr) + elif isinstance(expr, Call): + return self.visit_call(expr) + elif isinstance(expr, Let): + return self.visit_let(expr) + elif isinstance(expr, LocalVar): + return self.visit_local_var(expr) + elif isinstance(expr, GlobalVar): + return self.visit_global_var(expr) + elif isinstance(expr, If): + return self.visit_if(expr) + elif isinstance(expr, Tuple): + return self.visit_tuple(expr) + elif isinstance(expr, Constant): + return self.visit_constant(expr) + else: + raise Exception(f"warning unhandled case: {type(expr)}") + + def visit_function(self, _: Function) -> T: + raise Exception("Abstract method please implement me.") + + def visit_let(self, _: Let) -> T: + raise Exception("Abstract method please implement me.") + + def visit_call(self, _: Call) -> T: + raise Exception("Abstract method please implement me.") + + def visit_local_id(self, _: LocalVar) -> T: + raise Exception("Abstract method please implement me.") + + def visit_type(self, typ: Type) -> Type: + return typ + + def visit_if(self, _: If) -> T: + raise Exception("Abstract method please implement me.") + + def visit_tuple(self, _: Tuple) -> T: + raise Exception("Abstract method please implement me.") + + def visit_constant(self, _: Constant) -> T: + raise Exception("Abstract method please implement me.") + + def visit_global_var(self, _: GlobalVar) -> T: + raise Exception("Abstract method please implement me.") + + @classmethod + def to_pass(cls) -> Callable[[Environment], Callable[[GlobalVar, Function], Function]]: + def _outer_wrapper(env): + visitor = cls(env) + def _inner_wrapper(var, func): + return visitor.visit(func) + return _inner_wrapper + return _outer_wrapper + +class ExprVisitor(AbstractExprVisitor[Expr]): + """A functional visitor over Expr in Python.""" + + def visit_function(self, fn: Function) -> Expr: + new_body = self.visit(fn.body) + return Function( + list(fn.params), + fn.ret_type, new_body, + fn.type_params) + + def visit_let(self, let: Let) -> Expr: + new_var = self.visit(let.var) + new_value_type = self.visit_type(let.value_type) + new_val = self.visit(let.value) + new_body = self.visit(let.body) + return Let(new_var, new_val, new_body, new_value_type) + + def visit_call(self, call: Call) -> Expr: + new_fn = self.visit(call.fn) + new_args = [self.visit(arg) for arg in call.args] + return Call(new_fn, new_args, call.attrs) + + def visit_local_var(self, local_var: LocalVar) -> Expr: + return local_var + + def visit_global_id(self, global_var: GlobalVar) -> Expr: + return global_var + + def visit_if(self, ite: If) -> Expr: + return If( + self.visit(ite.guard), + self.visit(ite.true_b), + self.visit(ite.false_b)) + + def visit_tuple(self, tup: Tuple) -> Expr: + return Tuple([self.visit(field) for field in tup.fields]) + + def visit_constant(self, const: Constant) -> Expr: + return const + +MMCacheKey = Tuple[GlobalVar, List[Type]] + +class Monomorphize(ExprVisitor): + """A monomorphization pass. + + Implements what is known as "monomorphization" in + classic compiler literature. This pass removes + polymorphism replacing calls to functions and + operators with type specialized versions. + """ + monomorph_map: Dict[MMCacheKey, Union[Op, Function]] + + # pylint: disable=super-init-not-called + def __init__(self, env: Environment) -> None: + self.env = env + # Stores (GlobalVar, Type), should eventually store attributes. + self.monomorph_map = {} + + # pylint: disable=no-else-return + def visit_call(self, call: Call) -> Expr: + import pdb; pdb.set_trace() + # cache_key = (call.fn, call.ty_args) + # if isinstance(call.fn, OperatorId): + # if cache_key in self.monomorph_map: + # op = self.monomorph_map[cache_key] + # new_args = [self.visit(arg) for arg in call.args] + # return Call(op, new_args, call.attrs) + # else: + # new_name = mangle(call.fn.name, call.ty_args) + # new_id = self.env.operator_id(new_name) + # self.monomorph_map[cache_key] = new_id + # op = self.env.lookup(call.fn) + # for arg in call.ty_args: + # if isinstance(arg, TypeParam): + # return call # raise Exception("...") # Fix me in the morning!!! + # new_op = concretize(new_id, op, call.ty_args, call.attrs) + # self.monomorph_map[cache_key] = new_op.id + # self.env.add(new_op) + # new_args = [self.visit(arg) for arg in call.args] + # return Call(new_op.id, new_args, call.attrs) + # elif isinstance(call.fn, GlobalVar): + # if cache_key in self.monomorph_map: + # op_name = self.monomorph_map[cache_key] + # new_args = [self.visit(arg) for arg in call.args] + # return Call(op_name, new_args, call.attrs) + # else: + # defn = self.env.lookup(call.fn) + # new_id = self.env.global_id(defn.id.name + str(1)) + # cache_key = (call.fn, call.ty_args) + # self.monomorph_map[cache_key] = new_id + # new_body = self.visit(type_specialize(call.ty_args, defn.body)) + # new_body = Function( + # [], new_body.params, new_body.ret_type, new_body.body) + # new_ty = check_expr(self.env, new_body) + # # TODO(@jroesch): move into C++ + # # TODO(@joresch): implement and call name mangler + # defn = Defn(new_id, new_ty, new_body) + # self.env.add(defn) + # self.visit_item(defn) + # return Call(new_id, call.args, call.attrs) + # elif isinstance(call.fn, Function): + # new_args = [self.visit(arg) for arg in call.args] + # new_func = type_specialize(call.ty_args, call.fn) + # new_func = self.visit(new_func) + # new_func = Function([], + # new_func.params, + # new_func.ret_type, + # new_func.body) + # check_expr(self.env, new_func) + # return Call(new_func, call.args, call.attrs) + # else: + # new_fn = self.visit(call.fn) + # new_args = [self.visit(arg) for arg in call.args] + # return Call(new_fn, new_args, call.attrs) + + +# TODO(@jroesch): Fix up my type +__tgt_host__ = __tgt__ = "llvm" +__relay_tvm_context__ = tvm.cpu() + diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index d54f47e25197..47ebc5501cab 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -1,6 +1,6 @@ """Relay core operators.""" # operator defs -from .op import get, register, Op +from .op import get, register, Op, compile_ops # Operators from .tensor import * diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 4540b19f5ccf..d351e6cdc88d 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -3,7 +3,7 @@ from ..base import register_relay_node from ..expr import Expr -from ..._ffi.function import Function +from ..._ffi.function import Function, register_func from ...api import convert @register_relay_node @@ -72,6 +72,46 @@ def _register(v): return v return _register(value) if value else _register +def compile_ops(op_names): + """Register an operator property of an operator. + + + Parameters + ---------- + op_name : str + The name of operator + + attr_key : str + The attribute name. + + value : object, optional + The value to set + + level : int, optional + The priority level + + Returns + ------- + fregister : function + Register function if value is not specified. + """ + fake_map = {} + for name in op_names: + fake_map[name] = LocalVar(name) + if isinstance({}, dict): + fake_map = None + return [] # _CompileOpsToModule(fake_map) + +# TODO(@jroesch): We should port to C++, just need to figure out how to write this code. +@register_func("relay.opt.compile_ops") +def _compile_ops(op_impls): + lowered = [] + for local, sch, inputs in op_impls: + lfn = tvm.lower(sch, inputs, name=local.name_hint) + lowered.append(lfn) + + # TOOD(@jroesch): Where should we read these settings from + return tvm.build(lowered, target='llvm', target_host=tvm.cpu(0)) _init_api("relay.op", __name__) diff --git a/python/tvm/relay/to_tvm.py b/python/tvm/relay/to_tvm.py index 137230ace63a..d191e078dffe 100644 --- a/python/tvm/relay/to_tvm.py +++ b/python/tvm/relay/to_tvm.py @@ -4,12 +4,10 @@ from typing import Dict, Any, List, Tuple import attr - -from relay.frontend import get_env -from . import ir -from .tyck import get_checked_type -from .opt import AbstractExprVisitor, compile_ops_to_module -from ._make import Operator_is_generic +from .ir_pass import AbstractExprVisitor +from .op import compile_ops +from .type import TensorType +from .expr import LocalVar, Function, Let, Call @attr.s(auto_attribs=True) @@ -71,7 +69,7 @@ def to_json(self) -> Any: } -def from_tensor(typ: ir.TensorType) -> Tuple[str, List[int]]: +def from_tensor(typ: TensorType) -> Tuple[str, List[int]]: dtype = typ.dtype.dtype shape = typ.shape dims = [] @@ -83,7 +81,7 @@ def from_tensor(typ: ir.TensorType) -> Tuple[str, List[int]]: class TVMRTSCompiler(AbstractExprVisitor[NodeRef]): """The compiler from Relay to the TVM runtime system.""" nodes: List[Node] - id_map: Dict[ir.LocalId, NodeRef] + id_map: Dict[LocalVar, NodeRef] def __init__(self) -> None: self.nodes = [] @@ -94,10 +92,10 @@ def add_node(self, node: Node) -> NodeRef: ident = len(self.nodes) - 1 return NodeRef(ident) - def add_binding(self, ident: ir.LocalId, ref: NodeRef) -> None: + def add_binding(self, ident: LocalVar, ref: NodeRef) -> None: self.id_map[ident] = ref - def let_bind(self, ident: ir.LocalId, node: Node) -> NodeRef: + def let_bind(self, ident: LocalVar, node: Node) -> NodeRef: ref = self.add_node(node) self.add_binding(ident, ref) return ref @@ -105,10 +103,10 @@ def let_bind(self, ident: ir.LocalId, node: Node) -> NodeRef: def get_node(self, ref: NodeRef) -> Node: return self.nodes[ref.ident] - def lookup(self, ident: ir.LocalId) -> NodeRef: + def lookup(self, ident: LocalVar) -> NodeRef: return self.id_map[ident] - def compile(self, func: ir.Function) -> None: + def compile(self, func: Function) -> None: """Compile a single function into a graph.""" # TODO: (@jroesch) Restore me # assert len(fn.ty_params) == 0 @@ -132,30 +130,30 @@ def compile(self, func: ir.Function) -> None: # become our output node. self.get_node(output_ref).is_output = True - def visit_let(self, let: ir.Let) -> NodeRef: + def visit_let(self, let: Let) -> NodeRef: """Visit the Let binding, by first traversing its value, then setting the metadata on the returned NodeRef. Finally visit the body, and return the NodeRef corresponding to it. """ - ident = let.id + ident = let.var val = let.value body = let.body # Need to add type info? val_ref = self.visit(val) - dtype, shape = from_tensor(get_checked_type(val)) + dtype, shape = from_tensor(val.checked_type()) val_node = self.get_node(val_ref) val_node.attrs["dtype"] = dtype val_node.attrs["shape"] = shape self.add_binding(ident, val_ref) return self.visit(body) - def visit_local_id(self, ident: ir.LocalId) -> NodeRef: + def visit_local_id(self, ident: LocalVar) -> NodeRef: return self.lookup(ident) - def visit_call(self, call: ir.Call) -> NodeRef: + def visit_call(self, call: Call) -> NodeRef: inputs = [] for arg in call.args: inputs.append(self.visit(arg).to_json()) @@ -219,20 +217,21 @@ def to_json(self) -> str: return json.dumps(json_dict) -def compile_to_tvm(func): +def compile(func): """Compile a single function to the components needed by the TVM RTS. """ - env = get_env() - iids = [] + op_names = [] - # Why do I need to call items? - for op in env.operators(): - if not Operator_is_generic(op): - iids.append(op.id) + # # Why do I need to call items? + # for op in env.operators(): + # if not Operator_is_generic(op): + # iids.append(op.id) # TODO(@jroesch): Need to write test case for this - mod = compile_ops_to_module(env, iids) + print("above") + mod = compile_ops(op_names) + print("below") comp = TVMRTSCompiler() comp.compile(func) graph_json = comp.to_json() diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index a1a754615350..db7f11fb9e2b 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -3,10 +3,10 @@ * \file environment.cc * \brief The global environment in Relay. */ -#include #include -#include #include +#include +#include #include "./../pass/resolve.h" // #include "tvm/relay/util/rang.h" @@ -16,8 +16,7 @@ namespace relay { using tvm::IRPrinter; using namespace runtime; -Environment EnvironmentNode::make( - tvm::Map global_funcs) { +Environment EnvironmentNode::make(tvm::Map global_funcs) { std::shared_ptr n = std::make_shared(); n->functions = std::move(global_funcs); return Environment(n); @@ -26,11 +25,11 @@ Environment EnvironmentNode::make( GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { auto global_id = global_map_.find(str); if (global_id != global_map_.end()) { - return (*global_id).second; + return (*global_id).second; } else { - auto id = GlobalVarNode::make(str); - this->global_map_.Set(str, id); - return id; + auto id = GlobalVarNode::make(str); + this->global_map_.Set(str, id); + return id; } } @@ -39,7 +38,8 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { * definition will trigger an exception, otherwise we will * update the definition if and only if it is type compatible. */ -void EnvironmentNode::Add(const GlobalVar& var, const Function & func, bool update) { +void EnvironmentNode::Add(const GlobalVar &var, const Function &func, + bool update) { // Type check the item before we add it to the environment. auto env = GetRef(this); @@ -72,14 +72,14 @@ void EnvironmentNode::Add(const GlobalVar& var, const Function & func, bool upda } } -void EnvironmentNode::Update(const GlobalVar& var, const Function & func) { +void EnvironmentNode::Update(const GlobalVar &var, const Function &func) { this->Add(var, func, true); } -void EnvironmentNode::Remove(const GlobalVar&) { +void EnvironmentNode::Remove(const GlobalVar &) { // Clarify with @tqchen about how to use COW to do this. throw Error("NYI"); - // this->items.erase(id); + // this->items.erase(id); } Function EnvironmentNode::Lookup(const GlobalVar &var) { @@ -96,15 +96,14 @@ Function EnvironmentNode::Lookup(const std::string &str) { return this->Lookup(id); } -void EnvironmentNode::Merge(const Environment & env) { +void EnvironmentNode::Merge(const Environment &env) { for (auto pair : env->functions) { this->functions.Set(pair.first, pair.second); } } - inline SourceName EnvironmentNode::AddSource(std::string file_name, - std::string source) { + std::string source) { return this->source_map_.AddSource(file_name, source); } @@ -130,18 +129,35 @@ void EnvironmentNode::DisplayErrors() { // // Build the cursor. // // Fix this code, hardwired to compute alignment of pointer. - // size_t spaces = error_marker.size() + line_info.size() + file_name.size() + + // size_t spaces = error_marker.size() + line_info.size() + file_name.size() + // + // sp->col_offset - 3; // std::string cursor = "~~~~^~~~~"; // for (size_t i = 0; i < spaces; i++) { // std::cout << " "; // } - // std::cout << rang::fg::red << cursor << " " << err.msg << rang::style::reset + // std::cout << rang::fg::red << cursor << " " << err.msg << + // rang::style::reset // << std::endl; // } } +void EnvironmentNode::Transform(EnvironmentNode::Transformer transformer) { + Array to_process; + for (auto var_and_func : this->functions) { + to_process.push_back(var_and_func.first); + } + + auto for_each = transformer(GetRef(this)); + for (auto var : to_process) { + auto func = this->functions[var]; + auto transformed = for_each(var, func); + this->Add(var, transformed, true); + } +} + + TVM_REGISTER_API("relay._make.Environment") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = EnvironmentNode::make(args[0]); @@ -180,6 +196,11 @@ TVM_REGISTER_API("relay._env.Environment_Merge") env->Merge(args[1]); }); +TVM_REGISTER_API("relay._env.Environment_Transform") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + env->Transform(args[1]); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const EnvironmentNode *node, diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 664947425b53..e02a3163e8e7 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -1,4 +1,8 @@ +#include #include +#include +#include + #include #include @@ -132,5 +136,55 @@ TVM_REGISTER_API("relay.op._Register") } }); +bool IsGeneric(const Op& op) { + if (auto ty_func = op.as()) { + return ty_func->type_params.size() == 0; + } else { + return false; + } +} + +using namespace runtime; + +Module CompileOpsToModule(const std::vector & op_names) { + PackedFunc compile_ops = GetPackedFunc("relay.op.compile_ops"); + tvm::Array> args; + + auto compiler_map = Op::GetAttr("FRelayOpCompiler"); + + for (auto op_name : op_names) { + Op op = Op::Get(op_name); + + if (IsGeneric(op)) { + auto compiler = compiler_map[op]; + tvm::Array pair = + compiler(op->name, op->op_type); + //TODO(@jroesch): I can't pass strings across what should be the interface here. + tvm::Array triple = {LocalVarNode::make(op->name), pair[0], pair[1]}; + args.push_back(triple); + } else { + throw dmlc::Error("it is impossible to compile generic operators."); + } + } + + // Nothing to do, bail out earlier. + // TVM will complain if we try to generate a module of size 0. + if (args.size() == 0) { + return Module(nullptr); + } + + return compile_ops(args); +} + +TVM_REGISTER_API("relay.op._CompileOpsToModule") +.set_body([](TVMArgs args, TVMRetValue* ret) { + tvm::Map map = args[0]; + std::vector names; + for (auto pair : map) { + names.push_back(pair.first); + } + *ret = CompileOpsToModule(names); +}); + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index f18a67bcffc9..f513e36c9a30 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -33,7 +33,7 @@ struct ResolveTypeType : TypeFVisitor { } }; -struct ResolveTypeExpr : ExprFVisitor<> { +struct ResolveTypeExpr : ExprFVisitor { const TypeUnifier &unifier; explicit ResolveTypeExpr(const TypeUnifier &unifier) : unifier(unifier) {} @@ -53,6 +53,7 @@ struct ResolveTypeExpr : ExprFVisitor<> { // term, then resolve e's old type and write // it back into the new node. auto new_e = ExprFVisitor::VisitExpr(e); + std::cout << e << std::endl; CHECK(e->checked_type_.defined()); auto resolved_cty = VisitType(e->checked_type_); new_e->checked_type_ = resolved_cty; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 1adfb95d1e15..2ea205b511b1 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -211,6 +211,9 @@ CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { auto rtype = resolve(param->type); + // This is a special case ... not sure if there is a better way + // to handle this. + param->var->checked_type_ = rtype; return {ParamNode::make(param->var, rtype), rtype}; } @@ -545,7 +548,7 @@ Expr TypeInferencer::resolve(const Expr& e) { Expr InferType(const Environment &env, const Expr &e) { TypeInferencer ti(env); auto checked_expr = ti.Infer(e); - return checked_expr.expr; + return ti.resolve(checked_expr.expr); } Expr InferType(const Environment &env, const GlobalVar & var, const Function & func) { @@ -556,7 +559,7 @@ Expr InferType(const Environment &env, const GlobalVar & var, const Function & f auto checked_expr = ti.Infer(func); auto map_node = env->functions.CopyOnWrite(); map_node->data.erase(var.node_); - return checked_expr.expr; + return ti.resolve(checked_expr.expr); } diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 1ae78441e166..51833e13e475 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -5,7 +5,10 @@ from tvm.relay.ir_builder import IRBuilder, float_type, int_type from tvm.relay.ir_builder import func_type, tensor_type, into_ast from tvm.relay.env import Environment +from tvm.relay.ir_pass import Monomorphize from tvm.relay.op import log, add, equal, subtract +from tvm.relay.expr import Function +from tvm.relay import to_tvm def has_type(expr, typ, env=Environment({})): checked_expr = check_expr(env, expr) @@ -15,14 +18,26 @@ def decl_has_type(env, name, typ): func = env.lookup(name) return func.checked_type() == typ + +def run(env, expr): + if not isinstance(expr, Function): + expr = Function([], None, expr, []) + + env.add("main", expr) + env.transform(Monomorphize.to_pass()) + main = env.lookup("main") + graph_json, mod, _ = to_tvm.compile(main) + import pdb; pdb.set_trace() + def test_monomorphic_let(): "Program: let x = 1; return x" b = IRBuilder() x = b.let('x', 1.0, value_type=float_type(64)) b.ret(x) - prog, _ = b.get() + prog, env = b.get() assert has_type(prog, float_type(64)) + run(env, prog) def test_single_op(): "Program: fn (x : float32) { let t1 = f(x); t1 }" @@ -33,6 +48,25 @@ def test_single_op(): b.ret(t1) assert has_type(func.to_func(), func_type([float_type()], float_type())) +def test_binary_op(): + """ + Program: + fn (x, y) { + return x + y; + } + """ + b = IRBuilder() + x = b.param('x', tensor_type(5, 5, 5)) + y = b.param('y', tensor_type(5, 5, 5)) + with b.function(x, y) as func: + b.ret(add(x, y)) + b.ret(func) + prog, env = b.get() + ttype = tensor_type(5, 5, 5) + expected_ty = func_type([ttype, ttype], ttype) + assert has_type(func.to_func(), expected_ty) + run(env, prog) + def test_dual_op(): """Program: fn (x : Tensor[f32, (10, 10)]) { @@ -40,8 +74,7 @@ def test_dual_op(): let t2 = add(t1, x); return t1; } - """ - pass + """ b = IRBuilder() with b.function(('x', tensor_type(10, 10))) as func: x, = func.param_ids() @@ -58,7 +91,6 @@ def f(x : Tensor[f32, (10, 10)]) { return lx; } """ - pass b = IRBuilder() x = b.param('x') with b.decl('f', x): @@ -90,10 +122,11 @@ def f(n: i32, data: f32) -> f32 { b.ret(data) b.ret(f(into_ast(2.0), into_ast(10000.0))) assert decl_has_type(b.env, 'f', func_type([int_type(), float_type()], float_type())) - + if __name__ == "__main__": - test_monomorphic_let() - test_single_op() - test_dual_op() - test_decl() - test_recursion() + # test_monomorphic_let() + # test_single_op() + test_binary_op() + # test_dual_op() + # test_decl() + # test_recursion() From 888914afb6813f7c459fe5534cf7fe22dafc49a7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 5 Sep 2018 15:50:08 -0700 Subject: [PATCH 71/88] WIP debugging --- include/tvm/relay/op.h | 2 +- nnvm/python/nnvm/_base.py | 7 +- python/tvm/relay/expr.py | 9 + python/tvm/relay/ir_pass.py | 109 ++++++----- python/tvm/relay/op/_tensor.py | 52 ++++++ python/tvm/relay/op/op.py | 70 ++++--- python/tvm/relay/to_tvm.py | 54 +++--- src/relay/ir/op.cc | 172 +++++++++++------- src/relay/pass/type_infer.cc | 1 + .../relay/test_tyck_eval_integration.py | 18 +- 10 files changed, 314 insertions(+), 180 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 2e8d090f6625..756451e66768 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -35,7 +35,7 @@ class OpNode : public relay::ExprNode { /*! \brief name of the operator */ std::string name; /*! \brief the type of the operator */ - Type op_type; + mutable FuncType op_type; /*! * \brief detailed description of the operator * This can be used to generate docstring automatically for the operator. diff --git a/nnvm/python/nnvm/_base.py b/nnvm/python/nnvm/_base.py index 29390a2201bf..63b2f815ad9b 100644 --- a/nnvm/python/nnvm/_base.py +++ b/nnvm/python/nnvm/_base.py @@ -22,7 +22,12 @@ numeric_types = (float, int, np.float32, np.int32) # this function is needed for python3 # to convert ctypes.char_p .value back to python str - py_str = lambda x: x.decode('utf-8') + def py_str(x): + try: + return x.decode('utf-8') + except: + print(x) + # py_str = lambda x: x.decode('utf-8') else: string_types = basestring numeric_types = (float, int, long, np.float32, np.int32) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index ec0cfd55ad62..1558853c2820 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -10,7 +10,16 @@ from . import _make class ExprBuilder(): + # def convert_args(self, def __call__(self, *args, **kwargs): + converted_args = [] + for arg in args: + import pdb; pdb.set_trace() + if isinstance(arg, Param): + converted_args.append(arg.var) + else: + converted_args.append(arg) + return Call(self, args, None, None) class Expr(NodeBase, ExprBuilder): diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 70d5f09237d8..8b49710f70ec 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -13,9 +13,10 @@ from .expr import Expr from .expr import Function, Let, Call, LocalVar from .expr import GlobalVar, If, Constant -from .type import Type +from .type import Type, TypeParam from .env import Environment from .op import Op +from .op.op import specialize_op # import relay.make as relay_mk # from relay import ir # from relay.env import Environment @@ -126,7 +127,7 @@ def visit_let(self, let: Let) -> Expr: return Let(new_var, new_val, new_body, new_value_type) def visit_call(self, call: Call) -> Expr: - new_fn = self.visit(call.fn) + new_fn = self.visit(call.op) new_args = [self.visit(arg) for arg in call.args] return Call(new_fn, new_args, call.attrs) @@ -148,7 +149,7 @@ def visit_tuple(self, tup: Tuple) -> Expr: def visit_constant(self, const: Constant) -> Expr: return const -MMCacheKey = Tuple[GlobalVar, List[Type]] +MMCacheKey = Tuple[Union[GlobalVar, str], List[Type]] class Monomorphize(ExprVisitor): """A monomorphization pass. @@ -168,60 +169,54 @@ def __init__(self, env: Environment) -> None: # pylint: disable=no-else-return def visit_call(self, call: Call) -> Expr: - import pdb; pdb.set_trace() - # cache_key = (call.fn, call.ty_args) - # if isinstance(call.fn, OperatorId): - # if cache_key in self.monomorph_map: - # op = self.monomorph_map[cache_key] - # new_args = [self.visit(arg) for arg in call.args] - # return Call(op, new_args, call.attrs) - # else: - # new_name = mangle(call.fn.name, call.ty_args) - # new_id = self.env.operator_id(new_name) - # self.monomorph_map[cache_key] = new_id - # op = self.env.lookup(call.fn) - # for arg in call.ty_args: - # if isinstance(arg, TypeParam): - # return call # raise Exception("...") # Fix me in the morning!!! - # new_op = concretize(new_id, op, call.ty_args, call.attrs) - # self.monomorph_map[cache_key] = new_op.id - # self.env.add(new_op) - # new_args = [self.visit(arg) for arg in call.args] - # return Call(new_op.id, new_args, call.attrs) - # elif isinstance(call.fn, GlobalVar): - # if cache_key in self.monomorph_map: - # op_name = self.monomorph_map[cache_key] - # new_args = [self.visit(arg) for arg in call.args] - # return Call(op_name, new_args, call.attrs) - # else: - # defn = self.env.lookup(call.fn) - # new_id = self.env.global_id(defn.id.name + str(1)) - # cache_key = (call.fn, call.ty_args) - # self.monomorph_map[cache_key] = new_id - # new_body = self.visit(type_specialize(call.ty_args, defn.body)) - # new_body = Function( - # [], new_body.params, new_body.ret_type, new_body.body) - # new_ty = check_expr(self.env, new_body) - # # TODO(@jroesch): move into C++ - # # TODO(@joresch): implement and call name mangler - # defn = Defn(new_id, new_ty, new_body) - # self.env.add(defn) - # self.visit_item(defn) - # return Call(new_id, call.args, call.attrs) - # elif isinstance(call.fn, Function): - # new_args = [self.visit(arg) for arg in call.args] - # new_func = type_specialize(call.ty_args, call.fn) - # new_func = self.visit(new_func) - # new_func = Function([], - # new_func.params, - # new_func.ret_type, - # new_func.body) - # check_expr(self.env, new_func) - # return Call(new_func, call.args, call.attrs) - # else: - # new_fn = self.visit(call.fn) - # new_args = [self.visit(arg) for arg in call.args] - # return Call(new_fn, new_args, call.attrs) + cache_key = (call.op, call.type_args) + new_args = [self.visit(arg) for arg in call.args] + + if cache_key in self.monomorph_map: + op = self.monomorph_map[cache_key] + new_args = [self.visit(arg) for arg in call.args] + return Call(op, new_args, call.attrs) + else: + if isinstance(call.op, Op): + poly_name = call.op.name + mono_name = mangle(poly_name, call.type_args) + for arg in call.type_args: + if isinstance(arg, TypeParam): + return call # raise Exception("...") # Fix me in the morning!!! + + mono_op = specialize_op(poly_name, mono_name, call.type_args) + self.monomorph_map[cache_key] = mono_op + return Call(mono_op, new_args,call.attrs, []) + elif isinstance(call.op, GlobalVar): + return call + # defn = self.env.lookup(call.op) + # new_id = self.env.global_id(defn.id.name + str(1)) + # cache_key = (call.op, call.type_args) + # self.monomorph_map[cache_key] = new_id + # new_body = self.visit(type_specialize(call.type_args, defn.body)) + # new_body = Function( + # [], new_body.params, new_body.ret_type, new_body.body) + # new_ty = check_expr(self.env, new_body) + # # TODO(@jroesch): move into C++ + # # TODO(@joresch): implement and call name mangler + # defn = Defn(new_id, new_ty, new_body) + # self.env.add(defn) + # self.visit_item(defn) + # return Call(new_id, call.args, call.attrs) + + elif isinstance(call.op, Function): + return call + # new_func = type_specialize(call.type_args, call.op) + # new_func = self.visit(new_func) + # new_func = Function([], + # new_func.params, + # new_func.ret_type, + # new_func.body) + # check_expr(self.env, new_func) + # return Call(new_func, call.args, call.attrs) + else: + new_fn = self.visit(call.op) + return Call(new_fn, new_args, call.attrs) # TODO(@jroesch): Fix up my type diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 08dedee0923c..da94ec89b380 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -1,4 +1,56 @@ """Backend compiler related feature regsitration""" +from .op import register +from ..type import FuncType, TensorType +from ...schedule import create_schedule +from ...api import placeholder +from topi import add +def type_to_placeholder(name, ty): + if isinstance(ty, TensorType): + return placeholder(ty.shape, name=name, dtype=ty.dtype) + else: + raise Exception("can only pass Tensor values to TVM operators") +def func_ty_to_placeholders(func_ty): + if isinstance(func_ty, FuncType): + arg_types = func_ty.arg_types + ret_type = func_ty.ret_type + args = [] + var = 0 + for arg in arg_types: + var += 1 + args.append(type_to_placeholder(f"Input{var}", arg)) + return args, ret_type + else: + raise Exception("error") +# def lookup_in_topi(name): +# try: +# f = eval(f"topi.{name}") +# except: +# f = eval(f"topi.nn.{name}") + +# return f + +# @tvm.register_func("nnvm.relay._default_op_compiler") +# def _default_op_compile(op_name: str, func_ty: ir.Type, attrs: ir.Attributes=None) -> Any: +# Inputs, ret_ty = func_ty_to_placeholders(func_ty) +# op = lookup_in_topi(op_name) +# Output = op(*Inputs) + +# if Output.dtype == 'uint1': +# import pdb; pdb.set_trace() +# Output = Output.astype('uint8') + +# schedule = tvm.create_schedule(Output.op) +# return [schedule, Inputs + [Output]] + + +def add_compiler(op_name, func_type, *args): + Inputs, ret_ty = func_ty_to_placeholders(func_type) + # op = lookup_in_topi(op_name) + Output = add(*Inputs) + schedule = create_schedule(Output.op) + return [schedule, Inputs + [Output]] + +register("add", "FRelayOpCompiler", add_compiler) \ No newline at end of file diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index d351e6cdc88d..bb589f40f138 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -5,6 +5,8 @@ from ..expr import Expr from ..._ffi.function import Function, register_func from ...api import convert +from ...container import Map +from ... import lower, build, cpu @register_relay_node class Op(Expr): @@ -78,40 +80,64 @@ def compile_ops(op_names): Parameters ---------- - op_name : str - The name of operator - - attr_key : str - The attribute name. - - value : object, optional - The value to set - - level : int, optional - The priority level + op_names : List[str] + A list of operator names to compile to machine code. Returns ------- - fregister : function - Register function if value is not specified. + A module containing the compiled TVM operators. """ - fake_map = {} - for name in op_names: - fake_map[name] = LocalVar(name) - if isinstance({}, dict): - fake_map = None - return [] # _CompileOpsToModule(fake_map) + return _CompileOpsToModule(*op_names) # TODO(@jroesch): We should port to C++, just need to figure out how to write this code. -@register_func("relay.opt.compile_ops") +@register_func("relay.op._compile_ops") def _compile_ops(op_impls): lowered = [] for local, sch, inputs in op_impls: - lfn = tvm.lower(sch, inputs, name=local.name_hint) + lfn = lower(sch, inputs, name=local.name_hint) lowered.append(lfn) # TOOD(@jroesch): Where should we read these settings from - return tvm.build(lowered, target='llvm', target_host=tvm.cpu(0)) + return build(lowered, target='llvm', target_host='llvm') _init_api("relay.op", __name__) +def specialize_op(op_name, new_op_name, type_args): + """Specializes an operator to a set of types and assigns it new_op_name. + + The idea is to take operators with generic types such as broadcasting + addition: + + add : forall (T : Type) (U : Type), (U, T) -> Broadcast(U, T) + + This is a function which is polymorphic over two types `T` and `U` and + takes a value of type `T` and one of `U` and returns `Broadcast` of U + and T. + + Broadcast is a type relation which relates U and T to an output type. + + The idea is that the above type is shorthand for: + + add : forall (T : Type) (U : Type) (O : Type), Broadcast(U, T, O) => (U, T) -> O + + That is a function from U and T to O where the typing relation between the values + is specified by Broadcast. + + We implement a basic Broadcasting rule in `type_relations.h` but users can specify + their own. + + If we know T=Tensor[(10, 10), dtype], U=Tensor[(10, 10), dtype] then the result + should be Tensor[(10, 10), dtype]. + + We can use SpecializeOp to implement this change of operator. + + Parameters + ---------- + op_name : str + The operator to be specialized. + + Returns + ------- + The specialized operator. + """ + return _SpecializeOp(op_name, new_op_name, type_args) \ No newline at end of file diff --git a/python/tvm/relay/to_tvm.py b/python/tvm/relay/to_tvm.py index d191e078dffe..f2c2a9ba5463 100644 --- a/python/tvm/relay/to_tvm.py +++ b/python/tvm/relay/to_tvm.py @@ -1,11 +1,11 @@ """A compiler from Relay programs to TVM's graph runtime. """ import json -from typing import Dict, Any, List, Tuple +from typing import Dict, Any, List, Tuple, Set import attr from .ir_pass import AbstractExprVisitor -from .op import compile_ops +from .op import compile_ops, Op from .type import TensorType from .expr import LocalVar, Function, Let, Call @@ -69,23 +69,23 @@ def to_json(self) -> Any: } +def shape_to_json(shape): + return [str(sh.value) for sh in shape] + def from_tensor(typ: TensorType) -> Tuple[str, List[int]]: - dtype = typ.dtype.dtype - shape = typ.shape - dims = [] - for dim in shape.shapes: - dims.append(dim.value) - return dtype, dims + return (typ.dtype, shape_to_json(typ.shape)) class TVMRTSCompiler(AbstractExprVisitor[NodeRef]): """The compiler from Relay to the TVM runtime system.""" nodes: List[Node] id_map: Dict[LocalVar, NodeRef] + all_ops: Set[Op] def __init__(self) -> None: self.nodes = [] self.id_map = {} + self.all_ops = set() def add_node(self, node: Node) -> NodeRef: self.nodes.append(node) @@ -116,11 +116,11 @@ def compile(self, func: Function) -> None: for param in params: dtype, shape = from_tensor(param.type) - node = InputNode(f"{param.id.name}", { + node = InputNode(f"{param.var.name_hint}", { "shape": shape, "dtype": dtype, }) - self.let_bind(param.id, node) + self.let_bind(param.var, node) # Then we compile the body into a graph which can depend # on input variables. @@ -150,7 +150,7 @@ def visit_let(self, let: Let) -> NodeRef: self.add_binding(ident, val_ref) return self.visit(body) - def visit_local_id(self, ident: LocalVar) -> NodeRef: + def visit_local_var(self, ident: LocalVar) -> NodeRef: return self.lookup(ident) def visit_call(self, call: Call) -> NodeRef: @@ -158,9 +158,13 @@ def visit_call(self, call: Call) -> NodeRef: for arg in call.args: inputs.append(self.visit(arg).to_json()) - # need to deal with name mangle - op_name = call.fn.name - op_node = OpNode("call_name", {}, op_name, inputs, {}) + assert isinstance(call.op, Op) + self.all_ops.add(call.op.name) + + op_name = call.op.name + attrs = { 'shape': shape_to_json(call.checked_type().shape), + 'dtype': call.checked_type().dtype } + op_node = OpNode("call_name", attrs, op_name, inputs, {}) return self.add_node(op_node) def to_json(self) -> str: @@ -221,18 +225,16 @@ def compile(func): """Compile a single function to the components needed by the TVM RTS. """ - op_names = [] - - # # Why do I need to call items? - # for op in env.operators(): - # if not Operator_is_generic(op): - # iids.append(op.id) - - # TODO(@jroesch): Need to write test case for this - print("above") - mod = compile_ops(op_names) - print("below") comp = TVMRTSCompiler() comp.compile(func) + op_names = list(comp.all_ops) + mod = compile_ops(op_names) graph_json = comp.to_json() - return graph_json, mod, None # params currently isn't supported by API + try: + import nnvm + graph = nnvm.graph.load_json(graph_json) + except Exception as e: + import traceback + traceback.print_tb(e.__traceback__) + import pdb; pdb.set_trace() + return graph, mod, None # params currently isn't supported by API diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index e02a3163e8e7..64467004a973 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -1,10 +1,11 @@ -#include #include +#include #include #include +#include "./../pass/type_subst.h" -#include #include +#include namespace dmlc { // enable registry @@ -25,7 +26,7 @@ struct OpManager { // global operator counter std::atomic op_counter{0}; // storage of additional attribute table. - std::unordered_map > attr; + std::unordered_map> attr; // frontend functions std::vector frontend_funcs; // get singleton of the @@ -38,8 +39,7 @@ struct OpManager { // find operator by name const Op& Op::Get(const std::string& name) { const OpRegistry* reg = dmlc::Registry::Find(name); - CHECK(reg != nullptr) - << "Operator " << name << " is not registered"; + CHECK(reg != nullptr) << "Operator " << name << " is not registered"; return reg->op(); } @@ -61,8 +61,8 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) { return *it->second.get(); } -void OpRegistry::UpdateAttr( - const std::string& key, TVMRetValue value, int plevel) { +void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, + int plevel) { OpManager* mgr = OpManager::Global(); std::lock_guard lock(mgr->mutex); std::unique_ptr& op_map = mgr->attr[key]; @@ -71,13 +71,11 @@ void OpRegistry::UpdateAttr( } uint32_t index = op_->index_; if (op_map->data_.size() <= index) { - op_map->data_.resize(index + 1, - std::make_pair(TVMRetValue(), 0)); + op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0)); } - std::pair & p = op_map->data_[index]; + std::pair& p = op_map->data_[index]; CHECK(p.second != plevel) - << "Attribute " << key - << " of operator " << this->name + << "Attribute " << key << " of operator " << this->name << " is already registered with same plevel=" << plevel; if (p.second < plevel) { op_map->data_[index] = std::make_pair(value, plevel); @@ -86,59 +84,57 @@ void OpRegistry::UpdateAttr( // Frontend APIs TVM_REGISTER_API("relay.op._ListOpNames") -.set_body_typed()>([]() { - Array ret; - for (const std::string& name : - dmlc::Registry::ListAllNames()) { - ret.push_back(tvm::Expr(name)); - } - return ret; - }); - -TVM_REGISTER_API("relay.op._GetOp") -.set_body_typed(Op::Get); + .set_body_typed()>([]() { + Array ret; + for (const std::string& name : + dmlc::Registry::ListAllNames()) { + ret.push_back(tvm::Expr(name)); + } + return ret; + }); +TVM_REGISTER_API("relay.op._GetOp").set_body_typed(Op::Get); TVM_REGISTER_API("relay.op._OpGetAttr") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; - auto op_map = Op::GetAttr(attr_name); - if (op_map.count(op)) { - *rv = op_map[op]; - } - }); - + .set_body([](TVMArgs args, TVMRetValue* rv) { + Op op = args[0]; + std::string attr_name = args[1]; + auto op_map = Op::GetAttr(attr_name); + if (op_map.count(op)) { + *rv = op_map[op]; + } + }); TVM_REGISTER_API("relay.op._Register") -.set_body([](TVMArgs args, TVMRetValue* rv) { - std::string op_name = args[0]; - std::string attr_key = args[1]; - runtime::TVMArgValue value = args[2]; - int plevel = args[3]; - auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); - // enable resgiteration and override of certain properties - if (attr_key == "num_inputs" && plevel > 128) { - reg.set_num_inputs(value); - } else if (attr_key == "attrs_type_key" && plevel > 128) { - reg.set_attrs_type_key(value); - } else { - // normal attr table override. - if (args[2].type_code() == kFuncHandle) { - // do an eager copy of the PackedFunc - PackedFunc f = args[2]; - // If we get a function from frontend, avoid deleting it. - OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f)); - reg.set_attr(attr_key, f, plevel); + .set_body([](TVMArgs args, TVMRetValue* rv) { + std::string op_name = args[0]; + std::string attr_key = args[1]; + runtime::TVMArgValue value = args[2]; + int plevel = args[3]; + auto& reg = + OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); + // enable resgiteration and override of certain properties + if (attr_key == "num_inputs" && plevel > 128) { + reg.set_num_inputs(value); + } else if (attr_key == "attrs_type_key" && plevel > 128) { + reg.set_attrs_type_key(value); } else { - reg.set_attr(attr_key, args[2], plevel); + // normal attr table override. + if (args[2].type_code() == kFuncHandle) { + // do an eager copy of the PackedFunc + PackedFunc f = args[2]; + // If we get a function from frontend, avoid deleting it. + OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f)); + reg.set_attr(attr_key, f, plevel); + } else { + reg.set_attr(attr_key, args[2], plevel); + } } - } - }); + }); bool IsGeneric(const Op& op) { if (auto ty_func = op.as()) { - return ty_func->type_params.size() == 0; + return ty_func->type_params.size() != 0; } else { return false; } @@ -146,8 +142,8 @@ bool IsGeneric(const Op& op) { using namespace runtime; -Module CompileOpsToModule(const std::vector & op_names) { - PackedFunc compile_ops = GetPackedFunc("relay.op.compile_ops"); +Module CompileOpsToModule(const std::vector& op_names) { + PackedFunc compile_ops = GetPackedFunc("relay.op._compile_ops"); tvm::Array> args; auto compiler_map = Op::GetAttr("FRelayOpCompiler"); @@ -155,12 +151,15 @@ Module CompileOpsToModule(const std::vector & op_names) { for (auto op_name : op_names) { Op op = Op::Get(op_name); - if (IsGeneric(op)) { + if (!IsGeneric(op)) { auto compiler = compiler_map[op]; - tvm::Array pair = - compiler(op->name, op->op_type); - //TODO(@jroesch): I can't pass strings across what should be the interface here. - tvm::Array triple = {LocalVarNode::make(op->name), pair[0], pair[1]}; + std::cout << "ABOVE CALL" << std::endl; + tvm::Array pair = compiler(op->name, op->op_type); + std::cout << "BELOW CALL" << std::endl; + // TODO(@jroesch): I can't pass strings across what should be the + // interface here. + tvm::Array triple = {LocalVarNode::make(op->name), pair[0], + pair[1]}; args.push_back(triple); } else { throw dmlc::Error("it is impossible to compile generic operators."); @@ -177,14 +176,49 @@ Module CompileOpsToModule(const std::vector & op_names) { } TVM_REGISTER_API("relay.op._CompileOpsToModule") -.set_body([](TVMArgs args, TVMRetValue* ret) { - tvm::Map map = args[0]; - std::vector names; - for (auto pair : map) { - names.push_back(pair.first); + .set_body([](TVMArgs args, TVMRetValue* ret) { + std::vector names; + for (auto i = 0; i < args.num_args; i++) { + names.push_back(args[i]); + } + std::cout << "Right here" << std::endl; + *ret = CompileOpsToModule(names); + }); + +Op SpecializeOp(const std::string& op_name, + const std::string& new_op_name, Array type_args) { + auto registry = ::tvm::relay::OpRegistry::Registry(); + auto op_reg = registry->__REGISTER_OR_GET__(op_name); + auto new_op_reg = registry->__REGISTER__(new_op_name).set_name(); + + auto fn_ty = op_reg.op()->op_type; + + tvm::Map subst_map; + + CHECK(fn_ty->type_params.size() == type_args.size()); + + // Build a subsitituion map up from the function type and type arguments. + // Eventually allow the type vars to be passed in. + for (auto i = 0; i < type_args.size(); i++) { + subst_map.Set(fn_ty->type_params[i], type_args[i]); } - *ret = CompileOpsToModule(names); -}); + + Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); + inst_ty = TypeSubst(fn_ty, subst_map); + FuncType new_op_ty = GetRef(inst_ty.as()); + new_op_reg.op()->op_type = new_op_ty; + + // Now we want to copy over some attributes. + PackedFunc compiler = Op::GetAttr("FRelayOpCompiler")[op_reg.op()]; + new_op_reg.set_attr("FRelayOpCompiler", compiler); + + return new_op_reg.op(); +} + +TVM_REGISTER_API("relay.op._SpecializeOp") + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = SpecializeOp(args[0], args[1], args[2]); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 2ea205b511b1..b624a5709ddd 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -210,6 +210,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { } CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { + // We should trigger error here and move param code direclty into function checking. auto rtype = resolve(param->type); // This is a special case ... not sure if there is a better way // to handle this. diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 51833e13e475..9e89e8813e08 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -1,6 +1,9 @@ """Test that type checker correcly computes types for expressions. """ +import tvm +import numpy as np +from nnvm import graph from tvm.relay.ir_pass import check_expr from tvm.relay.ir_builder import IRBuilder, float_type, int_type from tvm.relay.ir_builder import func_type, tensor_type, into_ast @@ -9,6 +12,7 @@ from tvm.relay.op import log, add, equal, subtract from tvm.relay.expr import Function from tvm.relay import to_tvm +from tvm.contrib import graph_runtime def has_type(expr, typ, env=Environment({})): checked_expr = check_expr(env, expr) @@ -19,14 +23,18 @@ def decl_has_type(env, name, typ): return func.checked_type() == typ -def run(env, expr): +def run(env, expr, inputs, shape): if not isinstance(expr, Function): expr = Function([], None, expr, []) env.add("main", expr) env.transform(Monomorphize.to_pass()) main = env.lookup("main") - graph_json, mod, _ = to_tvm.compile(main) + graph, lib, _ = to_tvm.compile(main) + module = graph_runtime.create(graph, lib, tvm.cpu(0)) + module.set_input(None, None, **inputs) + module.run() + out = module.get_output(0, out=tvm.nd.array(shape)) import pdb; pdb.set_trace() def test_monomorphic_let(): @@ -59,13 +67,15 @@ def test_binary_op(): x = b.param('x', tensor_type(5, 5, 5)) y = b.param('y', tensor_type(5, 5, 5)) with b.function(x, y) as func: - b.ret(add(x, y)) + b.ret(add(x.var, y.var)) b.ret(func) prog, env = b.get() ttype = tensor_type(5, 5, 5) expected_ty = func_type([ttype, ttype], ttype) assert has_type(func.to_func(), expected_ty) - run(env, prog) + x_data = np.random.rand(5, 5, 5) + y_data = np.random.rand(5, 5, 5) + run(env, prog, { 'x': x_data, 'y': y_data }, (5, 5, 5)) def test_dual_op(): """Program: From be285d625badcbb06445eab8061600fe97c7e81c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 13:07:51 -0700 Subject: [PATCH 72/88] Add another test case and do a little clean up --- python/tvm/relay/to_tvm.py | 16 +- src/relay/pass/type_infer.cc | 142 ------------------ .../relay/test_tyck_eval_integration.py | 64 ++++++-- 3 files changed, 56 insertions(+), 166 deletions(-) diff --git a/python/tvm/relay/to_tvm.py b/python/tvm/relay/to_tvm.py index f2c2a9ba5463..181251844a6d 100644 --- a/python/tvm/relay/to_tvm.py +++ b/python/tvm/relay/to_tvm.py @@ -70,7 +70,8 @@ def to_json(self) -> Any: def shape_to_json(shape): - return [str(sh.value) for sh in shape] + return [sh.value for sh in shape] + def from_tensor(typ: TensorType) -> Tuple[str, List[int]]: return (typ.dtype, shape_to_json(typ.shape)) @@ -162,8 +163,8 @@ def visit_call(self, call: Call) -> NodeRef: self.all_ops.add(call.op.name) op_name = call.op.name - attrs = { 'shape': shape_to_json(call.checked_type().shape), - 'dtype': call.checked_type().dtype } + attrs = {'shape': shape_to_json(call.checked_type().shape), + 'dtype': call.checked_type().dtype} op_node = OpNode("call_name", attrs, op_name, inputs, {}) return self.add_node(op_node) @@ -230,11 +231,4 @@ def compile(func): op_names = list(comp.all_ops) mod = compile_ops(op_names) graph_json = comp.to_json() - try: - import nnvm - graph = nnvm.graph.load_json(graph_json) - except Exception as e: - import traceback - traceback.print_tb(e.__traceback__) - import pdb; pdb.set_trace() - return graph, mod, None # params currently isn't supported by API + return graph_json, mod, None # params currently isn't supported by API diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index b624a5709ddd..6cc73d1b8fbe 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -138,8 +138,6 @@ class TypeInferencer : private ExprFunctor { Type resolve(const Type &t); Expr resolve(const Expr &e); CheckedExpr VisitFunction(const Function &f, bool generalize); - void CheckOp(Op op); - // Defn CheckDefn(Defn def); private: CheckedExpr VisitExpr_(const LocalVarNode *op) override; CheckedExpr VisitExpr_(const GlobalVarNode *op) override; @@ -218,43 +216,6 @@ CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { return {ParamNode::make(param->var, rtype), rtype}; } -// // We should probably generalize the subst code. -// struct GeneralizeTypeType : TypeFVisitor { -// Map vars_to_id; -// const TypeUnifier &unifier; - -// GeneralizeTypeType(Map vars_to_id, -// const TypeUnifier &unifier) -// : vars_to_id(vars_to_id), unifier(unifier) {} - -// Type VisitType_(const TypeVarNode *op) override { -// auto repr = unifier->subst(GetRef(op)); -// if (auto tvn = repr.as()) { -// auto ty_var = GetRef(tvn); -// if (vars_to_id.find(ty_var) != vars_to_id.end()) { -// return vars_to_id[ty_var]; -// } else { -// return ty_var; -// } -// } else { -// return this->VisitType(repr); -// } -// } -// }; - -// struct GeneralizeTypeExpr : ExprFVisitor<> { -// Map vars_to_id; -// const TypeUnifier &unifier; - -// GeneralizeTypeExpr(const TypeUnifier &unifier, -// Map vars_to_id) -// : vars_to_id(vars_to_id), unifier(unifier) {} - -// Type VisitType(const Type &t) { -// return GeneralizeTypeType(vars_to_id, unifier).VisitType(t); -// } -// }; - CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { // First we add the parameters to the context allowing us to check their // types. @@ -282,83 +243,6 @@ CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { return {FunctionNode::make(params, unified_rtype, checked_body.expr, {}), FuncTypeNode::make(param_types, unified_rtype, {}, {})}; }); - - // // typecheck body and ensure that it matches stated return type - // // TODO(sslyu): should the unified return type override the annotated - // one? Type checked_return = this->Check(f->body); Type ret_type = - // resolve(f->ret_type); Type unified = - // this->unify(simple_eval_shape(ret_type), - // simple_eval_shape(checked_return), f->span); - // return TypeArrowNode::make(arg_types, unified); - // }); - // if (generalize) { - // auto free_vars = free_type_vars(resolve(fn_type)); - // std::set dedup_free_vars; - - // for (auto free_var : free_vars) { - // auto repr = this->unifier->subst(free_var); - // if (auto new_free_var_node = repr.as()) { - // dedup_free_vars.insert(GetRef(new_free_var_node)); - // } else { - // // debug(repr); - // throw dmlc::Error( - // "internal error: this list should only contain type var - // nodes"); - // } - // } - - // Map vars_to_id; - - // GenFresh gf; - // for (auto free_var : dedup_free_vars) { - // vars_to_id.Set(free_var, gf.freshTV(free_var->kind)); - // } - - // fn_type = GeneralizeTypeType(vars_to_id, unifier).VisitType(fn_type); - // for (std::pair pair : vars_to_id) { - // // NB: In generalization we want to find type variables with - // // *no constraints* on them, and convert them to universally - // quantified - // // variables. - // // - // // i.e the program can be abstracted over the details of *that* type. - - // // For example a program that works irrespective of shape or - // datatype. - - // // In order to do this we find the set of free type variables in the - // // term, and then unify them with the fresh type ids we generate. - // // - // // Remember importantly these type variables still may appear in many - // // places in the program including both types and expressions. - - // // Our method for resolving these is to unify them with the variables - // // as we build the new quanitifer, changing from a program with - // "holes" - // // to one that is properly abstracted over. - - // // Finally later on we can iterate over the whole term and change - // from - // // type variables to these type ids. - // this->unify(pair.first, pair.second, pair.second->span); - // fn_type = TypeQuantifierNode::make(pair.second, fn_type); - // } - // } else { - // for (auto i = f->ty_params.size(); i > 0; i--) { - // auto ty_param = f->ty_params[i - 1]; - // auto ty_param_node = ty_param.as(); - // if (!ty_param_node) { - // throw dmlc::Error("internal error should be TypeParam"); - // } - // auto fresh_tid = - // TypeParamNode::make(ty_param_node->name, ty_param_node->kind); - // fn_type = - // TypeSubst(fn_type, GetRef(ty_param_node), fresh_tid); - // fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); - // } - // } - - // return fn_type; } CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode *op) { @@ -520,32 +404,6 @@ Expr TypeInferencer::resolve(const Expr& e) { return ::tvm::relay::Resolve(this->unifier, e); } -// Defn TypeInferencer::CheckDefn(Defn defn) { -// // This is to handle recursion, but we need to speculatively -// // put it in env, then remove it. -// env->items.insert({defn->id, defn}); - -// Type expected_ty = this->resolve(defn->type); - -// Expr body = defn->body; - -// auto checked_ty = Check(body); - -// try { -// Type uret_type = unify(expected_ty, checked_ty, defn->body->span); -// CHECK(is_fully_resolved(uret_type)); -// // Now let's clean up our work from earlier. -// env->items.erase(defn->id); -// return DefnNode::make(defn->id, uret_type, this->resolve(defn->body)); -// } catch (const UnificationError& err) { -// std::string msg = std::string("mismatch between `") + -// PrintType(env, expected_ty, WrapWidth(40)) + "` and -// `" + PrintType(env, checked_ty, WrapWidth(40)) + -// "`"; -// fatal_error(msg, defn->span); -// } -// } - Expr InferType(const Environment &env, const Expr &e) { TypeInferencer ti(env); auto checked_expr = ti.Infer(e); diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 9e89e8813e08..cd87fb83ec52 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -2,7 +2,7 @@ for expressions. """ import tvm -import numpy as np +import numpy as np from nnvm import graph from tvm.relay.ir_pass import check_expr from tvm.relay.ir_builder import IRBuilder, float_type, int_type @@ -13,15 +13,18 @@ from tvm.relay.expr import Function from tvm.relay import to_tvm from tvm.contrib import graph_runtime +import nnvm + def has_type(expr, typ, env=Environment({})): checked_expr = check_expr(env, expr) return checked_expr.checked_type() == typ + def decl_has_type(env, name, typ): func = env.lookup(name) return func.checked_type() == typ - + def run(env, expr, inputs, shape): if not isinstance(expr, Function): @@ -31,11 +34,14 @@ def run(env, expr, inputs, shape): env.transform(Monomorphize.to_pass()) main = env.lookup("main") graph, lib, _ = to_tvm.compile(main) - module = graph_runtime.create(graph, lib, tvm.cpu(0)) + # We use NNVM to load the graph right now because it populates node_row_ptr field. + nnvm_graph = nnvm.graph.load_json(graph) + module = graph_runtime.create(nnvm_graph, lib, tvm.cpu(0)) module.set_input(None, None, **inputs) module.run() - out = module.get_output(0, out=tvm.nd.array(shape)) - import pdb; pdb.set_trace() + out_nd_array = tvm.nd.array(np.empty(shape, dtype='float32')) + return module.get_output(0, out=out_nd_array) + def test_monomorphic_let(): "Program: let x = 1; return x" @@ -45,7 +51,8 @@ def test_monomorphic_let(): prog, env = b.get() assert has_type(prog, float_type(64)) - run(env, prog) + run(env, prog, [], float_type(64)) + def test_single_op(): "Program: fn (x : float32) { let t1 = f(x); t1 }" @@ -56,7 +63,8 @@ def test_single_op(): b.ret(t1) assert has_type(func.to_func(), func_type([float_type()], float_type())) -def test_binary_op(): + +def test_add_op(): """ Program: fn (x, y) { @@ -73,9 +81,34 @@ def test_binary_op(): ttype = tensor_type(5, 5, 5) expected_ty = func_type([ttype, ttype], ttype) assert has_type(func.to_func(), expected_ty) - x_data = np.random.rand(5, 5, 5) - y_data = np.random.rand(5, 5, 5) - run(env, prog, { 'x': x_data, 'y': y_data }, (5, 5, 5)) + x_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) + y_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) + result = run(env, prog, {'x': x_data, 'y': y_data}, (5, 5, 5)) + np.testing.assert_allclose( + x_data.asnumpy() + y_data.asnumpy(), result.asnumpy()) + +def test_add_broadcast_op(): + """ + Program: + fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { + return x + y; + } + """ + b = IRBuilder() + x = b.param('x', tensor_type(10, 4)) + y = b.param('y', tensor_type(5, 10, 1)) + with b.function(x, y) as func: + b.ret(add(x.var, y.var)) + b.ret(func) + prog, env = b.get() + ttype = tensor_type(5, 5, 5) + expected_ty = func_type([ttype, ttype], ttype) + assert has_type(func.to_func(), expected_ty) + x_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) + y_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) + result = run(env, prog, {'x': x_data, 'y': y_data}, (5, 10, 4)) + np.testing.assert_allclose( + x_data.asnumpy() + y_data.asnumpy(), result.asnumpy()) def test_dual_op(): """Program: @@ -84,7 +117,7 @@ def test_dual_op(): let t2 = add(t1, x); return t1; } - """ + """ b = IRBuilder() with b.function(('x', tensor_type(10, 10))) as func: x, = func.param_ids() @@ -109,6 +142,7 @@ def f(x : Tensor[f32, (10, 10)]) { _, env = b.get() assert decl_has_type(env, 'f', func_type([float_type()], float_type())) + def test_recursion(): """ Program: @@ -131,12 +165,16 @@ def f(n: i32, data: f32) -> f32 { with b.else_scope(): b.ret(data) b.ret(f(into_ast(2.0), into_ast(10000.0))) - assert decl_has_type(b.env, 'f', func_type([int_type(), float_type()], float_type())) + assert decl_has_type(b.env, 'f', func_type( + [int_type(), float_type()], float_type())) + # TODO(@jroesch): need evaluator or new runtime + # to execute this. if __name__ == "__main__": # test_monomorphic_let() # test_single_op() - test_binary_op() + test_add_op() + test_add_broadcast_op() # test_dual_op() # test_decl() # test_recursion() From d47f637a14b4cbfcbd356cb55df78df3b5093b88 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 13:08:39 -0700 Subject: [PATCH 73/88] Port docs from previous Relay version --- docs/api/python/relay/index.rst | 17 ++ docs/langref/relay/expressions.rst | 178 +++++++++++++++++++++ docs/langref/relay/index.rst | 17 ++ docs/langref/relay/intro.rst | 17 ++ docs/langref/relay/type_system.rst | 137 ++++++++++++++++ tutorials/relay/implement_fma_transform.py | 141 ++++++++++++++++ 6 files changed, 507 insertions(+) create mode 100644 docs/api/python/relay/index.rst create mode 100644 docs/langref/relay/expressions.rst create mode 100644 docs/langref/relay/index.rst create mode 100644 docs/langref/relay/intro.rst create mode 100644 docs/langref/relay/type_system.rst create mode 100644 tutorials/relay/implement_fma_transform.py diff --git a/docs/api/python/relay/index.rst b/docs/api/python/relay/index.rst new file mode 100644 index 000000000000..32db5daded2b --- /dev/null +++ b/docs/api/python/relay/index.rst @@ -0,0 +1,17 @@ +Relay API +========= + +This document contains the Python API to the Relay frontend, optimizer, and +compiler toolchain. + +Relay is a new high level intermediate representation for the TVM compiler +stack. Our goal is to generalize computation graphs provided by previous +languages to full differentiable programs. + +.. toctree:: + :maxdepth: 2 + + env + ir + make + unifier diff --git a/docs/langref/relay/expressions.rst b/docs/langref/relay/expressions.rst new file mode 100644 index 000000000000..37dc62c6bc24 --- /dev/null +++ b/docs/langref/relay/expressions.rst @@ -0,0 +1,178 @@ +================== +Expressions +================== + +Relay's IR is a pure expression oriented language, that has a +dataflow fragment and structured control flow. Although Relay's +representation is a tree, it is possible to view the dataflow +fragments as graph for purposes of writing and expressing +transformations. + +The below sections make an attempt to clearly split the dataflow +fragment from the control fragment. + +================== +Dataflow Expressions +================== + +First we will cover the set of nodes which do not involve control flow, +this fragment of the language is semantically equivalent to pure +computation graphs without control flow. + +Constants +~~~~~~~~~ +Relay programs can contain constant Tensor values, since in Relay all +values are either Tensors, Products, or Closures. We will discuss the +later two later, but we represent Tensor constants as `tvm.NDArray`, +allowing us to utilize normal operators for constant evaluation. + + +Constructors +~~~~~~~~ + +Relay supports a handful of constructors which we will cover below. A +constructor enables programs to build new values from arbitrary Relay +expressions. + + +We support four types of literals, literals are type polymorphic and can +assigned any base type. If we can not solve for a concrete type we apply +a defaulting rule. + +We support signed and unsigned integers, floating point numbers, booleans, +and tensor literals. + +The base type literals are designed to closely model literals in TVM's +expressions langauge. + +### Boolean Literals +TODO: don't have these in any form right now + +### Integer Literals +TODO: don't have these in any form right now + +Tensor Constructor +~~~~~~~~~~~~~~~ + +A tensor literal allows us to build a Tensor from other expressions. + +TODO: Example here + + +Tuple Constructor +~~~~~~~~~~~~~~~ + +We support tuple constructors which allows us to build a fixed-k sized +sequence of heterogenous data. These tuples match closely to Python's +and enable efficient projection of their members due to their fixed length. + + (a, b, c) : Tuple + + (a + b + c, d) : Tuple, Tensor> + +Function +~~~~~~~~ + +A function node represents a function, it contains a seqeuence of +parameters, a return type, and a body. + + fun (x : Float, y: Float) -> Float { x + y } + +Functions are first class in Relay, and can be used in any expression +position. Functions are the same as global functions, but do not have +an explicit name. You can use a function in conjunction with a let +binding to define locally recursive functions. + + let fact = fun (x : Float) -> Float { + if (x == 0) { + 0 + } else { + x * fact(x - 1) + }; + fact(10) + +Identifiers +~~~~~~~~~~~ + +All of the identifiers are valid expressions, you can use a local identifier, +global identifier, or intrinsic identifier anywhere an expression may appear. + +For example the below fragment of code is a valid expression. + + %ret = @global(intrinsic, %local) + +Let Binding +~~~~~~~~~~~ + +An immutable variable binding, allows the user to bind an +expression to a name. A let binding contains a local identifier, +an optional type, a value, and body expression which may +reference the bound identifier. + +We will first introduce a single binding with no type +anntoations:: + let %x = %a + %b; + x + +The value of a let binding is the value of the final expression +after evaluating the bindings it depends on. + +A user can write a sequence of let bindings, we can view +these blocks and pure dataflow +single binding. These blocks are pure dataflow, and can +be evaluated in any order, reordered up to dataflow. + +We support a sequence of bindings followed by a body which +is the continutation after executing the sequence of bindings. + +I believe this representation will be easier to manipulate then +the mixed dataflow/control flow comptuation graphs. +Data flow and control flow are strictly seperated in this representation +and we can easily syntactically discriminate. When in ANF there should only be +general control flow between `Assignment` nodes and not within the values bound +in bindings. + +This representation also makes it easy to apply reverse more since +sequences of assignments where the only control flow is call instructions +are treated by the algorithm uniformly, and each control flow construct +must be handled individualy. + +TODO Add Ref, ReadRef, WriteRef, Projection, + +Gradient +~~~~~~~~ + +The `Reverse` acts as a marker node, when the compiler encounters it +we will apply the reverse mode transformation to the enclosed function. + +We will employ static analysis and constant evaluation in order to +simplify the node's argument to a known function call target. + + +You can compute the reverse node of a function node like so: + +Cast +~~~~~ + +Cast the type of the `node` to `ty`. + +======================= +Control Flow Expression +======================= +Control flow expressions change network topology based on values +computed by previous expressions. + +Call +~~~~ + +Terms with function types in Relay are "callable", that can be invoked like +a function in a typical programming language by supplying a set of arguments. + +Instrinsics with functions types, definitions, and functions are all callable. + +If-Then-Else +~~~~~~~~~~~~ + +Relay has a simple if/then/else expression which allows programs to branch +on a single control value which must be of type `Bool`, i.e a zero-rank +tensor of booleans. diff --git a/docs/langref/relay/index.rst b/docs/langref/relay/index.rst new file mode 100644 index 000000000000..617e745acdfc --- /dev/null +++ b/docs/langref/relay/index.rst @@ -0,0 +1,17 @@ +Relay Language Reference +======================== + +This document is a work in progress language reference describing +Relay, TVM's high level intermediate representation. The name is an +allusion to interneurons which are often referred to as intermediate, +or relay neurons. + +We will continually iterate on this document as we evolve the new IR +and update accordingly. + +.. toctree:: + :maxdepth: 2 + + intro + expressions + type_system diff --git a/docs/langref/relay/intro.rst b/docs/langref/relay/intro.rst new file mode 100644 index 000000000000..617e745acdfc --- /dev/null +++ b/docs/langref/relay/intro.rst @@ -0,0 +1,17 @@ +Relay Language Reference +======================== + +This document is a work in progress language reference describing +Relay, TVM's high level intermediate representation. The name is an +allusion to interneurons which are often referred to as intermediate, +or relay neurons. + +We will continually iterate on this document as we evolve the new IR +and update accordingly. + +.. toctree:: + :maxdepth: 2 + + intro + expressions + type_system diff --git a/docs/langref/relay/type_system.rst b/docs/langref/relay/type_system.rst new file mode 100644 index 000000000000..91a634431d7c --- /dev/null +++ b/docs/langref/relay/type_system.rst @@ -0,0 +1,137 @@ +================== +Type System +================== + +We have briefly introduced types while detailing the the expression language +of Relay, but have fully laid out the type system. + +Although the majority of Relay programs require no type annotations, Relay +is statically typed. Each expression in Relay has a precisely known type. + +You might ask why we want a statically typed IR, there are multiple advantages. +- efficient layout and code generation for tensors +- TODO +- debugging transformations (most program transformations should be type perserving) + +We are able to omit these type annotations by a process known as type inference. +Type inference is a technique that has its roots in the programming language +community, and can be viewed as a method for generalizing shape inference to +run over arbitrary user programs. + +Static typing means we know before executing the program properties about +the values it manipulates. Static types are useful for compiler optimization +because they communicate properties about the data we manipulate, such as +runtime shape, data layout, storage. + +Most current IRs use "shape inference" to recover Tensor dimensions from the user +provided program. Machine learning users have enjoyed shape inference for +tensors because it allows them to generate performant code without giving up +on the expressivity of the input language. + +Because Relay is intended as an IR we require *some* type information to provide +full inference. We don't believe this to be an issue as many of the IR builder +inferfaces require some type information, or can generate IR based on their own +higher level inferences. + +We view this limited shape inference as a simpler form of type +inference. Instead of relying on an ad-hoc procedure for recovering type +information from a potentially dynamic program, we apply ideas from compiler and IR design. + +Below we briefly dicsuss the different kinds of types in Relay. + +===== +Types +===== + +BaseType +~~~~~~~~~~ +Relay has a notion of a BaseType, which captures the set of types +that can be stored in a Tensor. Relay's base types map to the set +of types supported by TVM. + +Each of the base types can be parametrized by number of bits, and +lanes for vectorization purposes. We support four base types any:`Bool`, +any:`Int` + +Type Variables +~~~~~~~~~~~~~~ + +Type Parameters +~~~~~~ +TODO: type parameter + +Kind +~~~~ + +Function Types +~~~~~~~~~~ +TODO: rename function type? + +TypeQuantifier +~~~~~~~~~~~~~~ +TODO + +Placeholders +~~~~~~~~~~~~ + +TODO + +Tuple Types +~~~~~~~~~~~~~ + +Reference Types +~~~~~~~~~~~~~~~ + +A reference type is simply a mutable memory location, since Relay is a pure +language by default we need a way to introduce limited mutability. In this +case mutable data is clearly marked in the type system as a reference type. + + Ref + +Tensor Type +~~~~~~~~~~~ + +Tensor values in Relay are typed with tensor types. A tensor type is +parametrized by a data type, and shape. The data type must be a base +type as enforced by the kind checking rules described in TODO. + +This restriction importantly means + +The shape may be any valid Relay shape as described in the below +section on shapes. + + +====== +Shapes +====== + +Shape Singleton +~~~~~~~~~~~~~~~ +I don't like this name + +ShapeAttr +~~~~~~~~~ +TODO + +ShapeProjection +~~~~~~~~~~~~~~~ +TODO + +ShapeBinaryOp +~~~~~~~~~~~~~ + +enum ShapeOp : int { + SHPLUS = 0, + SHSUB = 1, + SHMUL = 2, + SHDIV = 3 +}; + + +Shape Sequence +~~~~~~~~ +A sequence of shapes ... + + +ShapeBroadcast +~~~~~~~~~~~~~~ diff --git a/tutorials/relay/implement_fma_transform.py b/tutorials/relay/implement_fma_transform.py new file mode 100644 index 000000000000..8410dd6c1152 --- /dev/null +++ b/tutorials/relay/implement_fma_transform.py @@ -0,0 +1,141 @@ +"""How to use Relay to implement a simple two-operator fusion pass. +================================== +**Author**: `Jared Roesch `_ + +In this tutorial, we will demonstrate how to write a fusion pass for +the Relay IR. We demonstrate many Relay features including defining a +new operator, a program transform, the NNVM compatibility layer, +and executing the original and transformed programs on the Relay +evaluator and TVM runtime system. +""" + +################################################################ +# Introduction +# ------------------------- +# +# We use the fixed size for input tensors with 256 channels and 14 x 14 +# dimensions. The batch size is 256. Convolution filters contain 512 filters +# of size 3 x 3. We use stride size 1 and padding size 1 for the +# convolution. The following code defines the convolution algorithm in TVM. +# + +from typing import Any, Dict + +import numpy as np +import tvm +import topi + +from relay import ir, make as mk +from relay.ir import OperatorId +from relay.opt import ItemVisitor, ExprVisitor +from relay.frontend.nnvm import Variable, symbol +from relay.frontend.nnvm import compiler +from relay.frontend.global_env import get_env +from relay.operators.register import func_ty_to_placeholders, register_op +from relay.eval import defn_to_pyfunc +from relay.tyck import check_expr + +class ExprAtVisitor(ExprVisitor): + """A demo visitor which adds a new traversal strategy.""" + expr_map: Dict[ir.LocalId, ir.Expr] + + def __init__(self): + self.expr_map = {} + + def expr_at(self,id: ir.LocalId) -> ir.Expr: + try: + return self.expr_map[id] + except KeyError: + return id + + def visit_let(self, let: ir.Let) -> ir.Expr: + self.expr_map[let.id] = let.value + return super().visit_let(let) + +# let x = 1 + 1; +# ... x will map to 1 + 1 + +class FuseTwo(ExprAtVisitor): + """Rewrite b(a(x, y), z) into ab(x, y, z). """ + def __init__(self, a: OperatorId, b: OperatorId, a_and_b: OperatorId) -> None: + self.a = a + self.b = b + self.a_and_b = a_and_b + super().__init__() + + def visit_call(self, call: ir.Call) -> ir.Expr: + func = call.fn + if func == self.b: + assert len(call.args) == 2 # An assumption of this fusion code. + arg0 = self.expr_at(call.args[0]) + arg1 = self.expr_at(call.args[1]) + if isinstance(arg0, ir.Call) and arg0.fn == self.a: + new_call = mk.Call(self.a_and_b, arg0.args[:] + [arg1]) + elif isinstance(arg1, ir.Call) and arg1.fn == self.a: + new_call = mk.Call(self.a_and_b, arg1.args[:] + [arg0]) + else: + new_call = super().visit_call(call) + + return new_call + else: + return super().visit_call(call) + +def fma_compile(op_name: str, func_ty: ir.Type, attrs: ir.Attributes=None) -> Any: + Inputs, ret_ty = func_ty_to_placeholders(func_ty) + x, y, z = Inputs + Output = topi.multiply(topi.add(x, y), z) + # this is not a python function call, but builds an AST + schedule = tvm.create_schedule(Output.op) + return [schedule, Inputs + [Output]] + + +def register_fma(env: Any) -> None: + """Register TOPI's elementwise broadcast addition for the `+` operator.""" + shape = mk.TypeParam("s", ir.Kind.Shape) + bt = mk.TypeParam("bt", ir.Kind.BaseType) + in_out_type = mk.TensorType(bt, shape) + fma_type = mk.TypeQuantifier(bt, mk.TypeQuantifier(shape, mk.TypeArrow([in_out_type, in_out_type, in_out_type], in_out_type))) + # forall (bt: BaseTYpe) (s : Shape), Tensor[bt, s] -> Tensor[bt, s] -> Tensor[bt, s] + # TODO: no reverse mode + register_op(env, 'fma', fma_type, compiler=fma_compile) + +# Get the global environment for demo purposes. +env = get_env() + +register_fma(env) + +# A small helper which applies just our transform to the Relay expression. +def transform(e): + fuse = FuseTwo(env.add_id(), env.mul_id(), env.operator_id('fma')) + e = fuse.visit(e) + # Now let's use the type checker to make sure we didn't make a mistake. + check_expr(env, e) + return e + +# We will use NNVM frontend. +x = Variable('x') +y = Variable('y') +z = x * (x + y) + +relay_func = compiler.to_relay(z) + +print(f"Relay Function:\n{compiler.pp(relay_func)}") + +xform_func = transform(relay_func) + +print(f"Transformed Function:\n{compiler.pp(xform_func)}") + +# Use the evaluator. +norm = defn_to_pyfunc(env, relay_func) +xform = defn_to_pyfunc(env, xform_func) + +x = np.random.uniform(size=(10, 5, 10)).astype('float32') +y = np.random.uniform(size=(10, 5, 10)).astype('float32') + +norm_out = norm(x, y).asnumpy() +xform_out = xform(x, y).asnumpy() + +np.testing.assert_allclose(norm_out, xform_out) + +# Use the TVM runtime. + From 71dcd1ed74541fddef05991141bb40966469701c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 17:10:28 -0700 Subject: [PATCH 74/88] Update docs --- docs/api/python/index.rst | 1 + docs/api/python/relay/base.rst | 9 +++ docs/api/python/relay/env.rst | 6 ++ docs/api/python/relay/expr.rst | 36 +++++++++++ docs/api/python/relay/index.rst | 15 +++-- docs/api/python/relay/ir_builder.rst | 6 ++ docs/api/python/relay/ir_pass.rst | 3 + docs/api/python/relay/op.rst | 3 + docs/api/python/relay/to_tvm.rst | 3 + docs/api/python/relay/type.rst | 27 ++++++++ python/tvm/relay/type.py | 72 +++++++++++++++++++++- tutorials/relay/implement_fma_transform.py | 10 +-- 12 files changed, 178 insertions(+), 13 deletions(-) create mode 100644 docs/api/python/relay/base.rst create mode 100644 docs/api/python/relay/env.rst create mode 100644 docs/api/python/relay/expr.rst create mode 100644 docs/api/python/relay/ir_builder.rst create mode 100644 docs/api/python/relay/ir_pass.rst create mode 100644 docs/api/python/relay/op.rst create mode 100644 docs/api/python/relay/to_tvm.rst create mode 100644 docs/api/python/relay/type.rst diff --git a/docs/api/python/index.rst b/docs/api/python/index.rst index 59bd1795b7ec..ab411d77f4f4 100644 --- a/docs/api/python/index.rst +++ b/docs/api/python/index.rst @@ -23,4 +23,5 @@ Python API topi vta/index nnvm/index + relay/index hybrid diff --git a/docs/api/python/relay/base.rst b/docs/api/python/relay/base.rst new file mode 100644 index 000000000000..f0cec295ee6b --- /dev/null +++ b/docs/api/python/relay/base.rst @@ -0,0 +1,9 @@ +tvm.relay.base +----------- +.. automodule:: tvm.relay.base + +.. autoclass:: tvm.relay.base.NodeBase + :members: + +.. autoclass:: tvm.relay.base.Span + :members: \ No newline at end of file diff --git a/docs/api/python/relay/env.rst b/docs/api/python/relay/env.rst new file mode 100644 index 000000000000..eca7312d5bbb --- /dev/null +++ b/docs/api/python/relay/env.rst @@ -0,0 +1,6 @@ +tvm.relay.env +----------- +.. automodule:: tvm.relay.env + +.. autoclass:: tvm.relay.env.Environment + :members: \ No newline at end of file diff --git a/docs/api/python/relay/expr.rst b/docs/api/python/relay/expr.rst new file mode 100644 index 000000000000..cd0cb5c308c4 --- /dev/null +++ b/docs/api/python/relay/expr.rst @@ -0,0 +1,36 @@ +tvm.relay.expr +----------- +.. automodule:: tvm.relay.expr + +.. autoclass:: tvm.relay.expr.ExprBuilder + :members: + +.. autoclass:: tvm.relay.expr.Expr + :members: + +.. autoclass:: tvm.relay.expr.Constant + :members: + +.. autoclass:: tvm.relay.expr.Tuple + :members: + +.. autoclass:: tvm.relay.expr.LocalVar + :members: + +.. autoclass:: tvm.relay.expr.GlobalVar + :members: + +.. autoclass:: tvm.relay.expr.Param + :members: + +.. autoclass:: tvm.relay.expr.Function + :members: + +.. autoclass:: tvm.relay.expr.Call + :members: + +.. autoclass:: tvm.relay.expr.Let + :members: + +.. autoclass:: tvm.relay.expr.If + :members: \ No newline at end of file diff --git a/docs/api/python/relay/index.rst b/docs/api/python/relay/index.rst index 32db5daded2b..231d49df0e6d 100644 --- a/docs/api/python/relay/index.rst +++ b/docs/api/python/relay/index.rst @@ -4,14 +4,17 @@ Relay API This document contains the Python API to the Relay frontend, optimizer, and compiler toolchain. -Relay is a new high level intermediate representation for the TVM compiler -stack. Our goal is to generalize computation graphs provided by previous -languages to full differentiable programs. +Relay is the second generation high level intermediate representation for the TVM +compiler stack. .. toctree:: :maxdepth: 2 + base env - ir - make - unifier + expr + ir_builder + ir_pass + op + to_tvm + type diff --git a/docs/api/python/relay/ir_builder.rst b/docs/api/python/relay/ir_builder.rst new file mode 100644 index 000000000000..b12e3cc6cdd1 --- /dev/null +++ b/docs/api/python/relay/ir_builder.rst @@ -0,0 +1,6 @@ +tvm.relay.ir_builder +----------- +.. automodule:: tvm.relay.ir_builder + +.. autoclass:: tvm.relay.ir_builder.IRBuilder + :members: \ No newline at end of file diff --git a/docs/api/python/relay/ir_pass.rst b/docs/api/python/relay/ir_pass.rst new file mode 100644 index 000000000000..e2e3b432e5bd --- /dev/null +++ b/docs/api/python/relay/ir_pass.rst @@ -0,0 +1,3 @@ +tvm.relay.ir_pass +----------- +.. automodule:: tvm.relay.ir_pass \ No newline at end of file diff --git a/docs/api/python/relay/op.rst b/docs/api/python/relay/op.rst new file mode 100644 index 000000000000..fb8e9ce774c2 --- /dev/null +++ b/docs/api/python/relay/op.rst @@ -0,0 +1,3 @@ +tvm.relay.op +----------- +.. automodule:: tvm.relay.op \ No newline at end of file diff --git a/docs/api/python/relay/to_tvm.rst b/docs/api/python/relay/to_tvm.rst new file mode 100644 index 000000000000..72d01d123e0f --- /dev/null +++ b/docs/api/python/relay/to_tvm.rst @@ -0,0 +1,3 @@ +tvm.relay.to_tvm +----------- +.. automodule:: tvm.relay.to_tvm diff --git a/docs/api/python/relay/type.rst b/docs/api/python/relay/type.rst new file mode 100644 index 000000000000..d357df8f08ac --- /dev/null +++ b/docs/api/python/relay/type.rst @@ -0,0 +1,27 @@ +tvm.relay.type +----------- +.. automodule:: tvm.relay.type + +.. autoclass:: tvm.relay.type.Type + :members: + +.. autoclass:: tvm.relay.type.TensorType + :members: + +.. autoclass:: tvm.relay.type.Kind + :members: + +.. autoclass:: tvm.relay.type.TypeParam + :members: + +.. autoclass:: tvm.relay.type.TypeConstraint + :members: + +.. autoclass:: tvm.relay.type.FuncType + :members: + +.. autoclass:: tvm.relay.type.TypeCall + :members: + +.. autoclass:: tvm.relay.type.IncompleteType + :members: \ No newline at end of file diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index 70e4666e96f9..cde989603929 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -6,11 +6,12 @@ from tvm import expr from . import _make + class Type(NodeBase): """The base type for all Relay types.""" def __eq__(self, other) -> bool: - """Compares two Relay types for structural equivalence using + """Compare two Relay types for structural equivalence using alpha equivalence. """ return bool(_make._type_alpha_eq(self, other)) @@ -22,46 +23,97 @@ def same_as(self, other) -> bool: """Compares two Relay types by referential equality.""" return super().__eq__(other) + @register_relay_node class TensorType(Type): """A concrete TensorType in Relay, see tvm/relay/type.h for more details. + + This is the type assigned to tensor's with a known dype and shape. For + example a tensor of `float32` and `(5, 5)`. """ shape: List[expr.Expr] dtype: str span: Span def __init__(self, shape: List[expr.Expr], dtype: str) -> None: + """Construct a tensor type. + + Parameters + ---------- + shape: list of tvm.Expr + dtype: str + + Returns + ------- + tensor_type: The TensorType + """ self.__init_handle_by_constructor__(_make.TensorType, shape, dtype) + class Kind(IntEnum): """The kind of a type parameter, represents a variable shape, base type, type, or dimension. + + This controls what a type parameter is allowed to be instantiated + with. For example one's of kind BaseType can only be `float32`, `int32`, + and so on. """ ShapeVar = 0 Shape = 1 BaseType = 1 Type = 2 + @register_relay_node class TypeParam(Type): """A type parameter used for generic types in Relay, see tvm/relay/type.h for more details. + + A type parameter represents a type placeholder which will + be filled in later on. This allows the user to write + functions which are generic over types. """ var: expr.Var kind: Kind span: Span def __init__(self, var: expr.Var, kind: Kind) -> None: + """Construct a TypeParam. + + Parameters + ---------- + var: tvm.expr.Var + The tvm.Var which backs the type parameter. + + kind: Kind + The kind of the type parameter. + + Returns + ------- + type_param: TypeParam + The type parameter. + """ self.__init_handle_by_constructor__(_make.TypeParam, var, kind) + @register_relay_node class TypeConstraint(Type): """Abstract class representing a type constraint.""" pass + @register_relay_node class FuncType(Type): """A function type in Relay, see tvm/relay/type.h for more details. + + This is the type assigned to functions in Relay. They consist of + a list of type parameters which enable the definition of generic + fucntions, a set of type constraints which we omit for the time + being, a sequence of argument types, and a return type. + + We informally write them as: + `forall (type_params), (arg_types) -> ret_type + where type_constraints` """ type_params: List[TypeParam] type_constraints: List[TypeConstraint] @@ -70,7 +122,23 @@ class FuncType(Type): span: Span def __init__(self, arg_types: List[Type], ret_type: Type, type_params: List[TypeParam], type_constraints: List[TypeConstraint]) -> None: - self.__init_handle_by_constructor__(_make.FuncType, arg_types, ret_type, type_params, type_constraints) + """Construct a function type. + + Parameters + ---------- + arg_types: list of Type + ret_type: Type + type_params: list of TypeParam + type_constraints: list of TypeConstraint + + Returns + ------- + func_type: FuncType + The function type. + """ + self.__init_handle_by_constructor__( + _make.FuncType, arg_types, ret_type, type_params, type_constraints) + @register_relay_node class TypeCall(Type): diff --git a/tutorials/relay/implement_fma_transform.py b/tutorials/relay/implement_fma_transform.py index 8410dd6c1152..8c04e70aa846 100644 --- a/tutorials/relay/implement_fma_transform.py +++ b/tutorials/relay/implement_fma_transform.py @@ -13,11 +13,11 @@ # Introduction # ------------------------- # -# We use the fixed size for input tensors with 256 channels and 14 x 14 -# dimensions. The batch size is 256. Convolution filters contain 512 filters -# of size 3 x 3. We use stride size 1 and padding size 1 for the -# convolution. The following code defines the convolution algorithm in TVM. -# +# In this tutorial, we will demonstrate how to write a fusion pass for +# the Relay IR. We demonstrate many Relay features including defining a +# new operator, a program transform, the NNVM compatibility layer, +# and executing the original and transformed programs on the Relay +# evaluator and TVM runtime system. from typing import Any, Dict From b75c223855d6df81d963e36d0ff38602a198e6d6 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 17:11:20 -0700 Subject: [PATCH 75/88] Add skeleton for converting from NNVM models --- python/tvm/relay/from_nnvm.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 python/tvm/relay/from_nnvm.py diff --git a/python/tvm/relay/from_nnvm.py b/python/tvm/relay/from_nnvm.py new file mode 100644 index 000000000000..18a1112c2629 --- /dev/null +++ b/python/tvm/relay/from_nnvm.py @@ -0,0 +1,4 @@ +import nnvm + +def from_nnvm(graph): + import pdb; pdb.set_trace() From 92a8323faf6152e5c948e66787acbecae1d08924 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 17:27:08 -0700 Subject: [PATCH 76/88] Address more code review feedback --- include/tvm/relay/expr_functor.h | 2 +- include/tvm/relay/type.h | 5 ++--- python/tvm/relay/type.py | 4 ++-- src/relay/ir/op.cc | 10 +++------- src/relay/pass/unifier.cc | 2 +- 5 files changed, 9 insertions(+), 14 deletions(-) diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 4632733cbcfc..0d736212c9eb 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -23,7 +23,7 @@ namespace relay { * \sa tvm/ir_functor.h * * \tparam FType function signiture - * This type if only defined for FType with function signiture R(const Expr&, + * This type is only defined for FType with function signature R(const Expr&, * Args...) */ template diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 5d579b661280..54cf91cee4ec 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -131,9 +131,8 @@ class TypeParamNode : public TypeNode { kType = 3, }; /*! - * \brief The variable - * The variable itself is only meaningful when - * kind is ShapeVar, otherwise, we can only use the name. + * \brief The variable itself is only meaningful when + * kind is ShapeVar, otherwise, we only use the name. */ tvm::Var var; /*! \brief The kind of type parameter */ diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index cde989603929..d9fc1eff1fd0 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -60,8 +60,8 @@ class Kind(IntEnum): """ ShapeVar = 0 Shape = 1 - BaseType = 1 - Type = 2 + BaseType = 2 + Type = 3 @register_relay_node diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 64467004a973..064551efe9d6 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -132,12 +132,8 @@ TVM_REGISTER_API("relay.op._Register") } }); -bool IsGeneric(const Op& op) { - if (auto ty_func = op.as()) { - return ty_func->type_params.size() != 0; - } else { - return false; - } +bool IsGeneric(const FuncType & func_ty) { + return func_ty->type_params.size() != 0; } using namespace runtime; @@ -151,7 +147,7 @@ Module CompileOpsToModule(const std::vector& op_names) { for (auto op_name : op_names) { Op op = Op::Get(op_name); - if (!IsGeneric(op)) { + if (!IsGeneric(op->op_type)) { auto compiler = compiler_map[op]; std::cout << "ABOVE CALL" << std::endl; tvm::Array pair = compiler(op->name, op->op_type); diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 4558f6a24919..2c809a574cc6 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -29,7 +29,7 @@ void UnionFindNode::insert(const IncompleteType &v) { this->uf_map.Set(v, v); } void UnionFindNode::debug() { for (auto entry : this->uf_map) { - std::cout << entry.first << " = " << entry.second << std::endl; + RELAY_LOG(INFO) << entry.first << " = " << entry.second << std::endl; } } From 998e10a4d969042329ec8028ba668c0917365b3f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 19:02:34 -0700 Subject: [PATCH 77/88] Fix cpplint --- include/tvm/relay/error.h | 12 ++-- include/tvm/relay/op.h | 111 +++++++++++++++--------------- include/tvm/relay/pass.h | 2 +- include/tvm/relay/pass/alpha_eq.h | 7 +- include/tvm/relay/source_map.h | 19 +++-- include/tvm/relay/type.h | 20 +++--- src/relay/ir/environment.cc | 3 +- src/relay/ir/op.cc | 19 +++-- src/relay/op/tensor/elemwise.cc | 2 +- src/relay/op/type_relations.cc | 10 +-- src/relay/op/type_relations.h | 8 +-- src/relay/pass/incomplete_type.h | 14 ++-- src/relay/pass/kind_check.cc | 2 +- src/relay/pass/resolve.h | 8 +-- src/relay/pass/type_infer.cc | 2 +- src/relay/pass/type_subst.h | 8 +-- src/relay/pass/type_visitor.h | 10 +-- src/relay/pass/unifier.cc | 71 ++++++++++--------- src/relay/pass/unifier.h | 11 +-- src/relay/source_map.cc | 21 +++--- 20 files changed, 193 insertions(+), 167 deletions(-) diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 433c08abfd58..055cc42936df 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -13,11 +13,11 @@ namespace tvm { namespace relay { struct Error : dmlc::Error { - Error(std::string msg) : dmlc::Error(msg) {} + explicit Error(const std::string &msg) : dmlc::Error(msg) {} }; struct InternalError : Error { - InternalError(std::string msg) : Error(msg) {} + explicit InternalError(const std::string &msg) : Error(msg) {} }; struct SpannedError { @@ -26,14 +26,14 @@ struct SpannedError { SpannedError(std::string msg, Span sp) : msg(msg), sp(sp) {} }; -// FIX, we should change spanned errors to have a method which allow them to report on the Environment, -// inverting control to error definition. +// FIX, we should change spanned errors to have a method which allow them to +// report on the Environment, inverting control to error definition. struct FatalTypeError : dmlc::Error { - explicit FatalTypeError(const std::string & s) : dmlc::Error(s) {} + explicit FatalTypeError(const std::string &s) : dmlc::Error(s) {} }; struct TypecheckerError : public dmlc::Error { - explicit TypecheckerError(const std::string &msg) : Error(msg) {} + explicit TypecheckerError(const std::string &msg) : Error(msg) {} }; } // namespace relay diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 756451e66768..2d5627f2c844 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -6,23 +6,23 @@ #ifndef TVM_RELAY_OP_H_ #define TVM_RELAY_OP_H_ +#include +#include #include -#include -#include #include -#include -#include +#include +#include +#include "../attrs.h" #include "./base.h" -#include "./type.h" #include "./expr.h" -#include "../attrs.h" +#include "./type.h" namespace tvm { namespace relay { // forward declare name. -template +template class OpMap; class GenericOpMap; class OpRegistry; @@ -103,7 +103,7 @@ class Op : public relay::Expr { * \return An OpMap of specified attr_name. * \tparam ValueType The type of the attribute. */ - template + template inline static OpMap GetAttr(const std::string& attr_name); /*! * \brief Get an Op for a given operator name. @@ -129,9 +129,7 @@ class Op : public relay::Expr { class OpRegistry { public: /*! \return the operator */ - const Op& op() const { - return op_; - } + const Op& op() const { return op_; } /*! * \brief setter function during registration * Set the description of operator @@ -146,24 +144,25 @@ class OpRegistry { * \param description Description of the argument. * \return reference to self. */ - inline OpRegistry& add_argument(const std::string &name, - const std::string &type, - const std::string &description); - /*! + inline OpRegistry& add_argument(const std::string& name, + const std::string& type, + const std::string& description); + /*! * \brief Attach the type function corresponding to the return type. * \param ty_func The type function to register for the return type. * \return reference to self. */ - inline OpRegistry& add_type_func(const std::string & type_func_name, TypeRelationFn type_fn); + inline OpRegistry& add_type_func(const std::string& type_func_name, + TypeRelationFn type_fn); - /*! + /*! * \brief Attach the type function corresponding to the return type. * \param ty_func The type function to register for the return type. * \return reference to self. */ inline OpRegistry& add_type_func( - const std::string & type_func_name, - std::function(const Array &, int)> type_fn); + const std::string& type_func_name, + std::function(const Array&, int)> type_fn); /*! * \brief Set the type key of attributes. @@ -196,10 +195,9 @@ class OpRegistry { * * \tparam ValueType The type of the value to be set. */ - template + template inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*) - const ValueType& value, - int plevel = 10); + const ValueType& value, int plevel = 10); // set the name of the op to be the same as registry inline OpRegistry& set_name() { // NOLINT(*) @@ -222,8 +220,7 @@ class OpRegistry { // return internal pointer to op. inline OpNode* get(); // update the attribute OpMap - TVM_DLL void UpdateAttr(const std::string& key, - TVMRetValue value, + TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value, int plevel); }; @@ -251,7 +248,7 @@ class GenericOpMap { * \return the const reference to the content value. * \tparam ValueType The content value type. */ - template + template inline ValueType get(const Op& op, ValueType def_value) const; private: @@ -268,7 +265,7 @@ class GenericOpMap { * \brief Map used to store meta-information about Op. * \tparam ValueType The type of the value stored in map. */ -template +template class OpMap { public: /*! @@ -294,15 +291,14 @@ class OpMap { private: friend class Op; // constructor - explicit OpMap(const GenericOpMap& map) - : map_(map) {} + explicit OpMap(const GenericOpMap& map) : map_(map) {} /*! \brief The internal map field */ const GenericOpMap& map_; }; // internal macros to make -#define RELAY_REGISTER_VAR_DEF \ - static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry & __make_ ## RelayOp +#define RELAY_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry& __make_##RelayOp /*! * \def RELAY_REGISTER_OP @@ -319,16 +315,18 @@ class OpMap { * * \endcode */ -#define RELAY_REGISTER_OP(OpName) \ - DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \ - ::tvm::relay::OpRegistry::Registry()->__REGISTER_OR_GET__(OpName).set_name() +#define RELAY_REGISTER_OP(OpName) \ + DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::relay::OpRegistry::Registry() \ + ->__REGISTER_OR_GET__(OpName) \ + .set_name() // implementations inline const OpNode* Op::operator->() const { return static_cast(node_.get()); } -template +template inline OpMap Op::GetAttr(const std::string& key) { return OpMap(Op::GetGenericAttr(key)); } @@ -337,14 +335,15 @@ inline OpNode* OpRegistry::get() { return const_cast(op_.operator->()); } -inline OpRegistry& OpRegistry::describe(const std::string& descr) { // NOLINT(*) +inline OpRegistry& OpRegistry::describe( + const std::string& descr) { // NOLINT(*) get()->description = descr; return *this; } -inline OpRegistry& OpRegistry::add_argument(const std::string &name, - const std::string &type, - const std::string &description) { +inline OpRegistry& OpRegistry::add_argument(const std::string& name, + const std::string& type, + const std::string& description) { std::shared_ptr n = std::make_shared(); n->name = name; n->type_info = type; @@ -354,13 +353,15 @@ inline OpRegistry& OpRegistry::add_argument(const std::string &name, } inline OpRegistry& OpRegistry::add_type_func( - const std::string & type_func_name, - std::function(const Array &, int)> type_fn) { - auto pfunc = runtime::TypedPackedFunc(const Array &, int)>(type_fn); + const std::string& type_func_name, + std::function(const Array&, int)> type_fn) { + auto pfunc = + runtime::TypedPackedFunc(const Array&, int)>(type_fn); return add_type_func(type_func_name, pfunc); } -inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name, TypeRelationFn type_fn) { +inline OpRegistry& OpRegistry::add_type_func(const std::string& type_func_name, + TypeRelationFn type_fn) { auto type_func = TypeRelationNode::make(type_func_name, 0, type_fn); std::vector type_params; @@ -397,7 +398,7 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) return *this; } -inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*) +inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*) const std::string& type_key) { get()->attrs_type_key = type_key; return *this; @@ -408,13 +409,10 @@ inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) return *this; } -template +template inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) - const std::string& attr_name, - const ValueType& value, - int plevel) { - CHECK_GT(plevel, 0) - << "plevel in set_attr must be greater than 0"; + const std::string& attr_name, const ValueType& value, int plevel) { + CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; TVMRetValue rv; rv = value; UpdateAttr(attr_name, rv, plevel); @@ -435,12 +433,12 @@ inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const { CHECK(op.defined()); const uint32_t idx = op->index_; CHECK(idx < data_.size() && data_[idx].second != 0) - << "Attribute " << attr_name_ - << " has not been registered for Operator " << op->name; + << "Attribute " << attr_name_ << " has not been registered for Operator " + << op->name; return data_[idx].first; } -template +template inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { CHECK(op.defined()); const uint32_t idx = op->index_; @@ -451,17 +449,18 @@ inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { } } -template +template inline int OpMap::count(const Op& op) const { return map_.count(op); } -template +template inline ValueType OpMap::operator[](const Op& op) const { return map_[op]; } -template -inline ValueType OpMap::get(const Op& op, ValueType def_value) const { +template +inline ValueType OpMap::get(const Op& op, + ValueType def_value) const { return map_.get(op, def_value); } diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index f92596c41179..46419bde3f97 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -47,4 +47,4 @@ bool KindCheck(const Environment& env, const Type& t); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_PASS_TYPECHECKER_H_ \ No newline at end of file +#endif // TVM_RELAY_PASS_H_ diff --git a/include/tvm/relay/pass/alpha_eq.h b/include/tvm/relay/pass/alpha_eq.h index 51b5b4dd8b70..87b5164462d7 100644 --- a/include/tvm/relay/pass/alpha_eq.h +++ b/include/tvm/relay/pass/alpha_eq.h @@ -3,8 +3,8 @@ * \file tvm/relay/alpha_eq.h * \brief Check expressions and types for structural equivalence. */ -#ifndef TVM_RELAY_ALPHA_EQ_H_ -#define TVM_RELAY_ALPHA_EQ_H_ +#ifndef TVM_RELAY_PASS_ALPHA_EQ_H_ +#define TVM_RELAY_PASS_ALPHA_EQ_H_ #include #include @@ -51,4 +51,5 @@ bool AlphaEqual(const Type& t1, const Type& t2); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_ALPHA_EQ_H_ +#endif // TVM_RELAY_PASS_ALPHA_EQ_H_ + diff --git a/include/tvm/relay/source_map.h b/include/tvm/relay/source_map.h index a4dbc20b30ff..277c3875a17f 100644 --- a/include/tvm/relay/source_map.h +++ b/include/tvm/relay/source_map.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2018 by Contributors * \file source_map.h - * \brief A representation of source files and a data structure for + * \brief A representation of source files and a data structure for * storing them. */ #ifndef TVM_RELAY_SOURCE_MAP_H_ @@ -14,8 +14,15 @@ namespace tvm { namespace relay { +/*! \brief A fragment of a source file used for error reporting. + * + * These can be registered by the frontends and are used for + * displaying errors. + */ struct SourceFragment { + /*! \brief The file name which the source fragment originates from. */ std::string file_name; + /*! \brief The lines of source corresponding to the fragment. */ std::vector source_lines; SourceFragment(const std::string& file_name, const std::string& source); @@ -25,6 +32,7 @@ struct SourceFragment { this->source_lines = sf.source_lines; } + /*! \brief The lines of source code originate at lines. */ std::string SourceAt(Span sp, int lines); }; @@ -33,12 +41,15 @@ struct SourceFragment { class SourceMap { /*! \brief Map from unique token to a fragment of a source file. */ std::unordered_map map_; + public: SourceMap() : map_() {} - SourceName AddSource(std::string file_name, std::string source); - const SourceFragment & GetSource(SourceName id) const; + /*! \brief Add a source fragment with the file name and source. */ + SourceName AddSource(const std::string& file_name, const std::string& source); + /*! \brief Retrieve a source fragment by source name. */ + const SourceFragment& GetSource(SourceName id) const; }; } // namespace relay } // namespace tvm -#endif // TVM_RELAY_SOURCE_MAP_H_ \ No newline at end of file +#endif // TVM_RELAY_SOURCE_MAP_H_ diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 54cf91cee4ec..f485e0d8d62f 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -126,9 +126,9 @@ class TypeParamNode : public TypeNode { enum Kind : int { /*! \brief template variable in shape expression */ kShapeVar = 0, - kShape = 1, + kShape = 1, kBaseType = 2, - kType = 3, + kType = 3, }; /*! * \brief The variable itself is only meaningful when @@ -200,8 +200,8 @@ class FuncTypeNode : public TypeNode { } TVM_DLL static FuncType make(tvm::Array arg_types, Type ret_type, - tvm::Array type_params, - tvm::Array type_constraints); + tvm::Array type_params, + tvm::Array type_constraints); static constexpr const char* _type_key = "relay.FuncType"; TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode); @@ -209,7 +209,8 @@ class FuncTypeNode : public TypeNode { RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); -using TypeRelationFn = runtime::TypedPackedFunc(const Array&, int)>; +using TypeRelationFn = + runtime::TypedPackedFunc(const Array&, int)>; /*! * \brief Opaque type relation, is an input-output relation on types. @@ -238,7 +239,8 @@ class TypeRelationNode : public RelayNode { v->Visit("num_args", &num_args); } - TVM_DLL static TypeRelation make(std::string name, int num_args, TypeRelationFn func_); + TVM_DLL static TypeRelation make(std::string name, int num_args, + TypeRelationFn func_); static constexpr const char* _type_key = "relay.TypeRelation"; TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, RelayNode); @@ -257,7 +259,7 @@ class TypeCallNode : public TypeNode { public: /*! \brief The type function to be called. */ Type func; - + /*! \brief The type arguments to the type function. */ tvm::Array args; @@ -290,9 +292,7 @@ class TupleTypeNode : public TypeNode { TupleTypeNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("fields", &fields); - } + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); } TVM_DLL static TupleType make(tvm::Array fields); diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index db7f11fb9e2b..b5f0d663d26a 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -149,7 +149,7 @@ void EnvironmentNode::Transform(EnvironmentNode::Transformer transformer) { to_process.push_back(var_and_func.first); } - auto for_each = transformer(GetRef(this)); + auto for_each = transformer(GetRef(this)); for (auto var : to_process) { auto func = this->functions[var]; auto transformed = for_each(var, func); @@ -157,7 +157,6 @@ void EnvironmentNode::Transform(EnvironmentNode::Transformer transformer) { } } - TVM_REGISTER_API("relay._make.Environment") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = EnvironmentNode::make(args[0]); diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 064551efe9d6..18a647798c9e 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -1,12 +1,18 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/op.cc + * \brief Resolve incomplete types to complete types. + */ #include #include #include #include -#include "./../pass/type_subst.h" #include #include +#include "./../pass/type_subst.h" + namespace dmlc { // enable registry DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry); @@ -132,7 +138,7 @@ TVM_REGISTER_API("relay.op._Register") } }); -bool IsGeneric(const FuncType & func_ty) { +bool IsGeneric(const FuncType& func_ty) { return func_ty->type_params.size() != 0; } @@ -181,12 +187,12 @@ TVM_REGISTER_API("relay.op._CompileOpsToModule") *ret = CompileOpsToModule(names); }); -Op SpecializeOp(const std::string& op_name, - const std::string& new_op_name, Array type_args) { +Op SpecializeOp(const std::string& op_name, const std::string& new_op_name, + Array type_args) { auto registry = ::tvm::relay::OpRegistry::Registry(); auto op_reg = registry->__REGISTER_OR_GET__(op_name); auto new_op_reg = registry->__REGISTER__(new_op_name).set_name(); - + auto fn_ty = op_reg.op()->op_type; tvm::Map subst_map; @@ -205,7 +211,8 @@ Op SpecializeOp(const std::string& op_name, new_op_reg.op()->op_type = new_op_ty; // Now we want to copy over some attributes. - PackedFunc compiler = Op::GetAttr("FRelayOpCompiler")[op_reg.op()]; + PackedFunc compiler = + Op::GetAttr("FRelayOpCompiler")[op_reg.op()]; new_op_reg.set_attr("FRelayOpCompiler", compiler); return new_op_reg.op(); diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index 76adfbbfb968..d6a04773b7fa 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -120,5 +120,5 @@ RELAY_REGISTER_OP("equal") .set_support_level(1) .add_type_func("BroadcastComp", BroadcastCompRel); -} // namespace relayv +} // namespace relay } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 32d81a1d445e..2a6efbcf71e4 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -29,7 +29,7 @@ int to_int(const tvm::Expr& e) { } Array IdentityRel(const Array& types, int num_args) { - CHECK(types.size() == 2); + CHECK_EQ(types.size(), 2); auto t1 = as_ttype(types[0]); if (t1 && types[1].as()) { return {t1, t1}; @@ -88,7 +88,7 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, smaller = sh2; } - CHECK(larger.size() == smaller.size()); + CHECK_EQ(larger.size(), smaller.size()); Array out_shape; for (int i = 0; i < smaller.size(); i++) { @@ -105,11 +105,11 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, } Array BroadcastRel(const Array& types, int num_args) { - CHECK(types.size() == 3); + CHECK_EQ(types.size(), 3); if (auto t1 = as_ttype(types[0])) { if (auto t2 = as_ttype(types[1])) { std::cout << t1->dtype << t2->dtype << std::endl; - CHECK(t1->dtype == t2->dtype); + CHECK_EQ(t1->dtype, t2->dtype); return {t1, t2, ConcreteBroadcast(t1, t2, t1->dtype)}; } } @@ -121,7 +121,7 @@ Array BroadcastRel(const Array& types, int num_args) { compute boolean results. */ Array BroadcastCompRel(const Array& types, int num_args) { - CHECK(types.size() == 3); + CHECK_EQ(types.size(), 3); if (auto t1 = as_ttype(types[0])) { if (auto t2 = as_ttype(types[1])) { return {t1, t2, ConcreteBroadcast(t1, t2, HalideIR::Bool())}; diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h index 71c98fef7da1..3597246b5a4a 100644 --- a/src/relay/op/type_relations.h +++ b/src/relay/op/type_relations.h @@ -4,11 +4,11 @@ * \brief A set of utilities and common functionality * for type relations. */ -#ifndef TVM_RELAY_TYPECK_RESOLVE_H_ -#define TVM_RELAY_TYPECK_RESOLVE_H_ +#ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_ +#define TVM_RELAY_OP_TYPE_RELATIONS_H_ -#include #include +#include namespace tvm { namespace relay { @@ -20,4 +20,4 @@ Array BroadcastCompRel(const Array & types, int num_args); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_TYPECK_RESOLVE_H_ +#endif // TVM_RELAY_OP_TYPE_RELATIONS_H_ diff --git a/src/relay/pass/incomplete_type.h b/src/relay/pass/incomplete_type.h index 3967b4e58657..78771dc6e9b7 100644 --- a/src/relay/pass/incomplete_type.h +++ b/src/relay/pass/incomplete_type.h @@ -4,8 +4,8 @@ * \brief A way to defined arbitrary function signature with dispatch on types. */ -#ifndef TVM_RELAY_PASS_INCOMPLETE_TYPE_H -#define TVM_RELAY_PASS_INCOMPLETE_TYPE_H +#ifndef TVM_RELAY_PASS_INCOMPLETE_TYPE_H_ +#define TVM_RELAY_PASS_INCOMPLETE_TYPE_H_ #include @@ -22,9 +22,7 @@ class IncompleteTypeNode : public TypeNode { public: TypeParamNode::Kind kind; - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("kind", &kind); - } + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("kind", &kind); } TVM_DLL static IncompleteType make(TypeParamNode::Kind kind); @@ -34,7 +32,7 @@ class IncompleteTypeNode : public TypeNode { RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type); -} // namespace relay -} // namespace tvm +} // namespace relay +} // namespace tvm -#endif // TVM_RELAY_PASS_INCOMPLETE_TYPE_H +#endif // TVM_RELAY_PASS_INCOMPLETE_TYPE_H_ diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index c3823c8c3a35..522eb93483fb 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -39,4 +39,4 @@ bool KindCheck(const Environment& env, const Type &t) { } } // namespace relay -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/relay/pass/resolve.h b/src/relay/pass/resolve.h index 495c9658238a..deb6558322b8 100644 --- a/src/relay/pass/resolve.h +++ b/src/relay/pass/resolve.h @@ -3,11 +3,11 @@ * \file tvm/relay/resolve.h * \brief Resolve incomplete types to complete types. */ -#ifndef TVM_RELAY_TYPECK_RESOLVE_H_ -#define TVM_RELAY_TYPECK_RESOLVE_H_ +#ifndef TVM_RELAY_PASS_RESOLVE_H_ +#define TVM_RELAY_PASS_RESOLVE_H_ -#include #include +#include #include "./unifier.h" namespace tvm { @@ -20,4 +20,4 @@ bool IsFullyResolved(const Type & t); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_TYPECK_RESOLVE_H_ +#endif // TVM_RELAY_PASS_RESOLVE_H_ diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 6cc73d1b8fbe..df896fa3936a 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -60,7 +60,7 @@ struct TypeContext { struct TypeNormalizer : TypeFVisitor { TypeUnifier unifier; - TypeNormalizer(const TypeUnifier &unifier) : unifier(unifier) {} + explicit TypeNormalizer(const TypeUnifier &unifier) : unifier(unifier) {} Type VisitType_(const TypeCallNode *ty_call_node) { auto ty_call = GetRef(ty_call_node); diff --git a/src/relay/pass/type_subst.h b/src/relay/pass/type_subst.h index 3c248fdce3b7..5b6956f8e451 100644 --- a/src/relay/pass/type_subst.h +++ b/src/relay/pass/type_subst.h @@ -1,10 +1,10 @@ /*! * Copyright (c) 2018 by Contributors - * \file typeck/type_subst.h + * \file src/tvm/relay/pass/type_subst.h * \brief Utility functions for substituting types. */ -#ifndef TVM_RELAY_TYPECK_TYPE_SUBST_H_ -#define TVM_RELAY_TYPECK_TYPE_SUBST_H_ +#ifndef TVM_RELAY_PASS_TYPE_SUBST_H_ +#define TVM_RELAY_PASS_TYPE_SUBST_H_ #include @@ -16,4 +16,4 @@ Type TypeSubst(const Type &type, tvm::Map subst_map); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_TYPECK_TYPE_SUBST_H_ +#endif // TVM_RELAY_PASS_TYPE_SUBST_H_ diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index f3c0f9a74fb7..d65d6c567b23 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -3,8 +3,8 @@ * \file type_visitor.h * \brief A wrapper around TypeFunctor for common use cases. */ -#ifndef TVM_RELAY_TYPE_VISITOR_H_ -#define TVM_RELAY_TYPE_VISITOR_H_ +#ifndef TVM_RELAY_PASS_TYPE_VISITOR_H_ +#define TVM_RELAY_PASS_TYPE_VISITOR_H_ #include #include "./type_functor.h" @@ -54,7 +54,7 @@ struct TypeVisitor : ::tvm::relay::TypeFunctor { // A functional visitor for rebuilding an AST in place. struct TypeFVisitor : TypeFunctor { Type VisitType_(const TensorTypeNode* op) override { - // TODO (@jroesch): maybe we should recursively visit + // TODO(@jroesch): maybe we should recursively visit return TensorTypeNode::make(op->shape, op->dtype); } @@ -63,7 +63,7 @@ struct TypeFVisitor : TypeFunctor { } Type VisitType_(const FuncTypeNode* op) override { - // TODO (@jroesch): handle poly + // TODO(@jroesch): handle poly // auto new_id = this->VisitType(op->var); // if (const TypeParamNode* tin = new_id.as()) { @@ -107,4 +107,4 @@ struct TypeFVisitor : TypeFunctor { } // namespace relay } // namespace tvm -#endif // TVM_RELAY_TYPE_VISITOR_H_ +#endif // TVM_RELAY_PASS_TYPE_VISITOR_H_ diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 2c809a574cc6..7735ca8b0482 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -5,11 +5,11 @@ * incomplete types. */ -#include +#include "./unifier.h" #include #include #include -#include "./unifier.h" +#include #include "./type_visitor.h" // #include "tvm/relay/typeck/kindchecker.h" @@ -33,7 +33,7 @@ void UnionFindNode::debug() { } } -void UnionFindNode::AssertAlphaEqual(const Type & l, const Type & r) { +void UnionFindNode::AssertAlphaEqual(const Type &l, const Type &r) { if (!AlphaEqual(l, r)) { std::stringstream ss; ss << "Incompatible parent types in UF:" << l << " and " << r; @@ -141,7 +141,7 @@ Type TypeUnifierNode::unify(const Type &t1, const Type &t2) { Type unified = this->VisitType(t1, t2); // if (!check_kind(unified)) { - // throw UnificationError("Invalid kinds in unified type"); + // throw UnificationError("Invalid kinds in unified type"); // } return unified; } @@ -167,32 +167,34 @@ Type TypeUnifierNode::subst(const Type &t) { // normalize first so substitutions in quantifiers will be correct Type ret = tvsubst.VisitType(t); // if (!check_kind(ret)) { - // std::stringstream ss; - // ss << "Invalid Kinds in substituted type!"; - // ss << t << std::endl; - // ss << ret << std::endl; - // throw SubstitutionError(ss.str()); + // std::stringstream ss; + // ss << "Invalid Kinds in substituted type!"; + // ss << t << std::endl; + // ss << ret << std::endl; + // throw SubstitutionError(ss.str()); // } return ret; } -Type TypeUnifierNode::VisitType(const Type & t1, const Type t2) { +Type TypeUnifierNode::VisitType(const Type &t1, const Type t2) { // When the right hand size is a type variable immediately unify. if (const IncompleteTypeNode *tvn2 = t2.as()) { return this->unifyWithIncompleteType(t1, GetRef(tvn2)); - // The TypeCallNode case is special and not symmetric. - // - // We flip the arguments so we hit the TypeCall and other case in there is - // ever a type call. + // The TypeCallNode case is special and not symmetric. + // + // We flip the arguments so we hit the TypeCall and other case in there is + // ever a type call. } else if (const TypeCallNode *tvn2 = t2.as()) { - return TypeFunctor::VisitType(t2, t1); + return TypeFunctor::VisitType(t2, t1); } else { - return TypeFunctor::VisitType(t1, t2); + return TypeFunctor::VisitType(t1, t2); } } -Type TypeUnifierNode::unifyWithIncompleteType(const Type &t1, const IncompleteType tv2) { - RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 << std::endl; +Type TypeUnifierNode::unifyWithIncompleteType(const Type &t1, + const IncompleteType tv2) { + RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 + << std::endl; // Fix unify to return new representative this->uf->unify(tv2, t1); auto rep = this->uf->find(tv2); @@ -235,7 +237,8 @@ Type TypeUnifierNode::VisitType_(const FuncTypeNode *t1, const Type rt2) { FuncType ft2 = GetRef(tan2); if (ft1->type_params.size() != ft2->type_params.size()) { - throw UnificationError("unable to unify functions with differing number of type parameters"); + throw UnificationError( + "unable to unify functions with differing number of type parameters"); } if (ft1->type_params.size() != 0) { @@ -282,7 +285,7 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { TensorType tt2 = GetRef(ttn2); if (!AlphaEqual(tt1, tt2)) { - throw UnificationError("dtypes do not match"); + throw UnificationError("dtypes do not match"); } RELAY_LOG(INFO) << "Unify Tensor Shape s1=" << tt1->shape @@ -290,8 +293,9 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { try { // Type unified_shape = this->VisitType(tt1->shape, tt2->shape); return rt2; - } catch (const UnificationError & err) { - std::cout << "Need to check constraint " << tt1->shape << " = " << tt2->shape << std::endl; + } catch (const UnificationError &err) { + std::cout << "Need to check constraint " << tt1->shape << " = " + << tt2->shape << std::endl; } // fix me @@ -328,15 +332,16 @@ Type TypeUnifierNode::VisitType_(const TupleTypeNode *t1, const Type rt2) { } Type TypeUnifierNode::VisitType_(const TypeRelationNode *tr1, const Type t2) { - if (const TypeRelationNode *tr2 = t2.as()) { - if (tr1 == tr2) { - return GetRef(tr1); - } else { - throw UnificationError("Cannot unify different type relations"); - } - } else { - throw UnificationError("Cannot unify type relation with another type of type"); - } + if (const TypeRelationNode *tr2 = t2.as()) { + if (tr1 == tr2) { + return GetRef(tr1); + } else { + throw UnificationError("Cannot unify different type relations"); + } + } else { + throw UnificationError( + "Cannot unify type relation with another type of type"); + } } Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { @@ -347,7 +352,8 @@ Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { // For now, we will only unify if they are equal. if (ty_call1->args.size() != tcn2->args.size()) { - throw UnificationError("Cannot unify calls of different number of arguments"); + throw UnificationError( + "Cannot unify calls of different number of arguments"); } // Unify members, if possible @@ -364,6 +370,5 @@ Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { } } - } // namespace relay } // namespace tvm diff --git a/src/relay/pass/unifier.h b/src/relay/pass/unifier.h index 64485768c2f0..0671a40c0d74 100644 --- a/src/relay/pass/unifier.h +++ b/src/relay/pass/unifier.h @@ -7,8 +7,8 @@ #ifndef TVM_RELAY_PASS_UNIFIER_H_ #define TVM_RELAY_PASS_UNIFIER_H_ -#include #include +#include #include "./type_functor.h" namespace tvm { @@ -62,7 +62,7 @@ class UnionFind : public NodeRef { explicit UnionFind(std::shared_ptr p) : NodeRef(p) {} // The union find structure is mutable so we do not use the standard macros - // and expose the pointer via `->`. + // and expose the pointer via `->`. UnionFindNode* operator->() const { return static_cast(node_.get()); } @@ -102,8 +102,9 @@ class TypeUnifierNode : public Node, private: /*! \brief Unify incomplete type with another type. */ Type unifyWithIncompleteType(const Type& t1, const IncompleteType tvn2); - /*! \brief Implements unification between two types with incomplete portions. */ - Type VisitType(const Type & t1, const Type t2) override; + /*! \brief Implements unification between two types with incomplete portions. + */ + Type VisitType(const Type& t1, const Type t2) override; // Visitor Cases Type VisitType_(const IncompleteTypeNode* t1, const Type t2) override; @@ -130,4 +131,4 @@ class TypeUnifier : public NodeRef { } // namespace relay } // namespace tvm -#endif // TVM_RELAY_TYPECK_UNIFIER_H_ +#endif // TVM_RELAY_PASS_UNIFIER_H_ diff --git a/src/relay/source_map.cc b/src/relay/source_map.cc index d784c7946954..9d3316cf38cf 100644 --- a/src/relay/source_map.cc +++ b/src/relay/source_map.cc @@ -14,15 +14,18 @@ namespace relay { using tvm::IRPrinter; using namespace tvm::runtime; -SourceFragment::SourceFragment(const std::string& file_name, const std::string& source) +SourceFragment::SourceFragment(const std::string& file_name, + const std::string& source) : file_name(file_name), source_lines({}) { - RELAY_LOG(INFO)<< "SourceFragment::SourceFragment source=" << source << std::endl; + RELAY_LOG(INFO) << "SourceFragment::SourceFragment source=" << source + << std::endl; std::stringstream source_stream; source_stream.str(source.c_str()); std::string line; while (std::getline(source_stream, line)) { - RELAY_LOG(INFO) << "SourceFragment::SourceFragment: line=" << line << std::endl; + RELAY_LOG(INFO) << "SourceFragment::SourceFragment: line=" << line + << std::endl; std::string copy(line); source_lines.push_back(copy); } @@ -38,7 +41,8 @@ std::string SourceFragment::SourceAt(Span sp, int max_lines = 1) { throw dmlc::Error("SourceFragment: index out of bounds"); } - auto lines = std::max(static_cast(max_lines), source_lines.size() - starting_line); + auto lines = std::max(static_cast(max_lines), + source_lines.size() - starting_line); for (size_t i = 0; i < lines; i++) { out << std::endl << this->source_lines.at(starting_line + i); @@ -46,11 +50,12 @@ std::string SourceFragment::SourceAt(Span sp, int max_lines = 1) { auto source_slice = out.str(); - RELAY_LOG(INFO) << "SourceFragment::SourceAt: source_slice=" << source_slice << std::endl; + RELAY_LOG(INFO) << "SourceFragment::SourceAt: source_slice=" << source_slice + << std::endl; return source_slice; } -SourceName SourceMap::AddSource(std::string file_name, std::string source) { +SourceName SourceMap::AddSource(const std::string & file_name, const std::string & source) { auto new_id = SourceNameNode::make(file_name); SourceFragment sfile(file_name, source); this->map_.insert({new_id, sfile}); @@ -62,9 +67,9 @@ const SourceFragment& SourceMap::GetSource(SourceName id) const { if (item != map_.end()) { return (*item).second; } else { - throw dmlc::Error("could not find requested source fragment"); + throw dmlc::Error("could not find requested source fragment"); } } } // namespace relay -} // namespace tvm \ No newline at end of file +} // namespace tvm From 3a98c3df2a97756873d985af5497161e521ad5db Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 19:11:31 -0700 Subject: [PATCH 78/88] Fix pylint --- nnvm/python/nnvm/_base.py | 7 +- python/tvm/relay/__init__.py | 10 +-- python/tvm/relay/base.py | 1 - python/tvm/relay/env.py | 14 ++- python/tvm/relay/expr.py | 65 ++++++++++---- python/tvm/relay/from_nnvm.py | 5 +- python/tvm/relay/ir_builder.py | 87 ++++++++++++------- python/tvm/relay/ir_pass.py | 23 +++-- python/tvm/relay/op/__init__.py | 1 + python/tvm/relay/op/_tensor.py | 20 +++-- python/tvm/relay/op/op.py | 19 ++-- python/tvm/relay/op/tensor.py | 5 +- python/tvm/relay/to_tvm.py | 3 +- python/tvm/relay/type.py | 34 ++++---- python/tvm/tensor.py | 5 ++ .../relay/test_tyck_eval_integration.py | 2 +- 16 files changed, 191 insertions(+), 110 deletions(-) diff --git a/nnvm/python/nnvm/_base.py b/nnvm/python/nnvm/_base.py index 63b2f815ad9b..29390a2201bf 100644 --- a/nnvm/python/nnvm/_base.py +++ b/nnvm/python/nnvm/_base.py @@ -22,12 +22,7 @@ numeric_types = (float, int, np.float32, np.int32) # this function is needed for python3 # to convert ctypes.char_p .value back to python str - def py_str(x): - try: - return x.decode('utf-8') - except: - print(x) - # py_str = lambda x: x.decode('utf-8') + py_str = lambda x: x.decode('utf-8') else: string_types = basestring numeric_types = (float, int, long, np.float32, np.int32) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index c36b9bcf8357..aae019c8d9c1 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -1,8 +1,12 @@ +# pylint: disable=wildcard-import """The Relay IR namespace containing the IR definition and compiler.""" from . import base from . import type as tpe from . import expr -from . import op + +# Operators +from .op import Op +from .op.tensor import * # Span Span = base.Span @@ -26,7 +30,3 @@ Let = expr.Let If = expr.If Var = LocalVar - -# Operators -from .op import Op -from .op.tensor import * diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index ee818617f629..0f3d2bc58d71 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -1,7 +1,6 @@ # pylint: disable=no-else-return, unidiomatic-typecheck """The base node types for the Relay language.""" from __future__ import absolute_import as _abs -from typing import Union from .._ffi.node import NodeBase, register_node as _register_tvm_node from . import _make diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 86c9ac794b4e..beef6fd1a62c 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -1,28 +1,26 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import """A global environment storing everything needed to interpret or compile a Realy program.""" -from typing import Union, List from .base import register_relay_node, NodeBase from . import _make from . import _env -import tvm @register_relay_node class Environment(NodeBase): - """The global Relay environment containing definitions, - primitives, options, and more. + """The global Relay environment containing functions, + options and more. """ def __init__(self, funcs) -> None: self.__init_handle_by_constructor__(_make.Environment, funcs) - + def add(self, var, func) -> None: if isinstance(var, str): var = _env.Environment_GetGlobalVar(self, var) _env.Environment_Add(self, var, func) - + def merge(self, other): return _env.Environment_Merge(self, other) - + def global_var(self, var): return _env.Environment_GetGlobalVar(self, var) @@ -31,6 +29,6 @@ def lookup(self, var): return _env.Environment_Lookup_str(self, var) else: return _env.Environment_Lookup(self, var) - + def transform(self, transformer): _env.Environment_Transform(self, transformer) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 1558853c2820..748b2aa1e282 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -1,20 +1,22 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The expression nodes of Relay.""" +from typing import List import tvm -from typing import Tuple as PyTuple, List -from enum import IntEnum from .base import Span, NodeBase, register_relay_node from .type import Type, TypeParam -from tvm import expr from ._ir_pass import _get_checked_type from . import _make + class ExprBuilder(): - # def convert_args(self, + """A set of methods useful for building expressions + from other expressions. + """ def __call__(self, *args, **kwargs): converted_args = [] for arg in args: - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() if isinstance(arg, Param): converted_args.append(arg.var) else: @@ -22,11 +24,14 @@ def __call__(self, *args, **kwargs): return Call(self, args, None, None) + class Expr(NodeBase, ExprBuilder): """The base type for all Relay exprressions.""" + def checked_type(self): return _get_checked_type(self) + @register_relay_node class Constant(Expr): """A constant tensor in Relay, see tvm/relay/type.h for more details. @@ -36,6 +41,7 @@ class Constant(Expr): def __init__(self, data: tvm.nd.NDArray) -> None: self.__init_handle_by_constructor__(_make.Constant, data) + @register_relay_node class Tuple(Expr): """A hetereogenous sequence of values. @@ -55,6 +61,7 @@ class LocalVar(Expr): def __init__(self, name_hint: str) -> None: self.__init_handle_by_constructor__(_make.LocalVar, name_hint) + @register_relay_node class GlobalVar(Expr): """A global variable in Relay.""" @@ -63,6 +70,7 @@ class GlobalVar(Expr): def __init__(self, name_hint: str) -> None: self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) + @register_relay_node class Param(Expr): """A function type in Relay, see tvm/relay/type.h for more details. @@ -70,47 +78,66 @@ class Param(Expr): var: LocalVar type: Type - def __init__(self, var: LocalVar, type: Type) -> None: - self.__init_handle_by_constructor__(_make.Param, var, type) + def __init__(self, var: LocalVar, ty: Type) -> None: + self.__init_handle_by_constructor__(_make.Param, var, ty) @register_relay_node class Function(Expr): + """A function in Relay, see tvm/relay/expr.h for more details.""" type_params: List[TypeParam] params: List[Param] ret_type: Type body: Expr - def __init__(self, params: List[Param], ret_type: Type, body: Expr, type_params: List[TypeParam]=[]) -> None: - self.__init_handle_by_constructor__(_make.Function, params, ret_type, body, type_params) + def __init__(self, + params: List[Param], + ret_type: Type, + body: Expr, + type_params: List[TypeParam] = None) -> None: + if not type_params: + type_params = [] + self.__init_handle_by_constructor__( + _make.Function, params, ret_type, body, type_params) + @register_relay_node class Call(Expr): - op: Expr - args: List[Expr] - # todo(@jroesch): add attrs + """A function call in Relay, see tvm/relay/expr.h for more details.""" + op: Expr + args: List[Expr] + # todo(@jroesch): add attrs + + def __init__(self, op: Expr, args: List[Expr], attrs, ty_args=None) -> None: + if not ty_args: + ty_args = [] + + self.__init_handle_by_constructor__( + _make.Call, op, args, attrs, ty_args) - def __init__(self, op: Expr, args: List[Expr], attrs, ty_args) -> None: - self.__init_handle_by_constructor__(_make.Call, op, args, attrs, ty_args) @register_relay_node class Let(Expr): + """A variable bindings in Relay, see tvm/relay/expr.h for more details.""" var: LocalVar value: Expr body: Expr - value_type: Type # should be type nanotation + # should be type annotation + value_type: Type def __init__(self, var: LocalVar, value: Expr, body: Expr, value_type: Type) -> None: - self.__init_handle_by_constructor__(_make.Let, var, value, body, value_type) + self.__init_handle_by_constructor__( + _make.Let, var, value, body, value_type) + @register_relay_node class If(Expr): + """A conditional expression in Relay, see tvm/relay/expr.h for more details.""" cond: Expr true_value: Expr false_value: Expr span: Span def __init__(self, cond: Expr, true_value: Expr, false_value: Expr) -> None: - self.__init_handle_by_constructor__(_make.If, cond, true_value, false_value) - - + self.__init_handle_by_constructor__( + _make.If, cond, true_value, false_value) diff --git a/python/tvm/relay/from_nnvm.py b/python/tvm/relay/from_nnvm.py index 18a1112c2629..9700ea955f59 100644 --- a/python/tvm/relay/from_nnvm.py +++ b/python/tvm/relay/from_nnvm.py @@ -1,4 +1,7 @@ +#pylint: disable-all +"""Convert an nnvm.graph.Graph into a tvm.relay.Expr""" import nnvm def from_nnvm(graph): - import pdb; pdb.set_trace() + """Convert an nnvm.graph.Graph into a tvm.relay.Expr""" + raise Exception("NYI") diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 098eb474c6ee..a271a537b290 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -1,10 +1,14 @@ +"""IR builder for the Relay IR. + +Enables users to construct Relay programs with a Python API. +""" from typing import Any import numpy as np import tvm from .type import FuncType, TensorType -from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function, If +from .expr import Expr, Constant, Let, LocalVar, Param, Function, If from .env import Environment -from . import op as _op + def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: """Convert Python values into the appropriate types @@ -24,6 +28,7 @@ def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: # raise Exception(f"can't convert {type(arg)} to a Relay AST") raise Exception(f"unsupported argument type {type(arg)}") + def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> Expr: if isinstance(arg, Expr): return arg @@ -35,6 +40,7 @@ def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> Expr: value = convert(arg, ctxt) return Constant(value) + class WithScope(object): """A wrapper for builder methods which introduce scoping.""" @@ -53,6 +59,7 @@ def __exit__(self, ptype, value, trace): class PartialFunc(): + """A wrapper around functions while they are being built.""" def __init__(self, params, ret_type, body, type_params): self.params = params self.ret_type = ret_type @@ -69,15 +76,20 @@ def to_func(self): self.body, self.type_params) - +#pylint: disable=invalid-name def _mk_let(bindings, ret_value): let_expr = ret_value - for var, value in reversed(list(bindings.items())): - let_expr = Let(var, value, let_expr, None) + for var, value, ty in reversed(list(bindings.items())): + let_expr = Let(var, value, let_expr, ty) return let_expr + class IRBuilder(): + """The IRBuilder class. + + Enables users to build up a Relay environment and program. + """ def __init__(self): self.bindings = [{}] self.scopes = [{}] @@ -85,13 +97,15 @@ def __init__(self): self.ret_values = [None] self.env = Environment({}) - def enter_scope(self, params=[]): + def enter_scope(self, params=None): + if not params: + params = [] + self.bindings.append({}) self.scopes.append({}) self.params.append(params) self.ret_values.append(None) - def exit_scope(self): bindings = self.bindings.pop() scopes = self.scopes.pop() @@ -99,14 +113,13 @@ def exit_scope(self): ret_value = self.ret_values.pop() return bindings, scopes, params, ret_value - - def bind(self, name, type, value): + #pylint: disable=invalid-name + def bind(self, name, ty, value): lv = LocalVar(name) self.scopes[-1][name] = lv - self.bindings[-1][lv] = value + self.bindings[-1][lv] = (value, ty) return lv - def let(self, name, value, value_type=None): if isinstance(value, Param): value = value.var @@ -117,6 +130,7 @@ def let(self, name, value, value_type=None): return self.bind(name, value_type, value) def function(self, *params): + """Construct a Relay function.""" relay_params = [] for param in params: name = param.var @@ -131,13 +145,12 @@ def function(self, *params): pfunc = PartialFunc(relay_params, None, None, []) def _on_exit(): - bindings, scope, params, ret_value = self.exit_scope() + bindings, _, _, ret_value = self.exit_scope() body = _mk_let(bindings, ret_value) pfunc.body = body return WithScope(pfunc, _on_exit) - def ret(self, x): if not self.ret_values[-1]: self.ret_values[-1] = into_ast(x) @@ -146,6 +159,7 @@ def ret(self, x): "return value already set, a function can only have one return value") def if_scope(self, cond): + """Construct the if branch an if expression with scoping.""" self.enter_scope() def _on_exit(): @@ -153,29 +167,30 @@ def _on_exit(): assert self.ret_values[-1] is None true_branch = _mk_let(bindings, ret_value) self.ret_values[-1] = If(cond, true_branch, None) - + return WithScope(10, _on_exit) - def else_scope(self): + """Construct the else branch of an if expression with scoping.""" self.enter_scope() def _on_exit(): bindings, _, _, ret_value = self.exit_scope() partial_if = self.ret_values[-1] - assert isinstance(partial_if, If) and partial_if.false_value is None + assert isinstance( + partial_if, If) and partial_if.false_value is None false_branch = _mk_let(bindings, ret_value) self.ret_values[-1] = If( - partial_if.cond, - partial_if.true_value, + partial_if.cond, + partial_if.true_value, false_branch) - + return WithScope(10, _on_exit) def param(self, name, ty=None): if not ty: ty = float_type() - + return Param(LocalVar(name), ty) # def params(*args): @@ -183,7 +198,7 @@ def param(self, name, ty=None): # while i < args.size(): # arg = args[i] # if isinstance(arg, str): - + def global_var(self, name: str): return self.env.global_var(name) @@ -197,8 +212,8 @@ def _on_exit(): return WithScope(10, _on_exit) - # def while_loop(cond) + def get(self): """Get the full program""" bindings = self.bindings.pop() @@ -215,33 +230,47 @@ def get(self): return _mk_let(bindings, self.ret_values[-1]), self.env + def bool_dtype(): return 'uint1' + def int_dtype(bits=32): return f'int{bits}' + def float_dtype(bits=32): return f'float{bits}' + def uint_dtype(bits=32): return f'uint{bits}' - -def int_type(bits=32, lanes=1): + + +def int_type(bits=32, _lanes=1): # TODO(@jroesch, @tqchen) How do we set lanes? return TensorType(tvm.convert([]), int_dtype(bits)) -def uint_type(bits=32, lanes=1): + +def uint_type(bits=32, _lanes=1): return TensorType(tvm.convert([]), uint_dtype(bits)) -def float_type(bits=32, lanes=1): + +def float_type(bits=32, _lanes=1): return TensorType(tvm.convert([]), float_dtype(bits)) -def bool_type(lanes=1): - return TensorType(tvm.convert([]), bool_dtype(bits)) + +def bool_type(_lanes=1): + return TensorType(tvm.convert([]), bool_dtype()) + def tensor_type(*shape, dtype='float32'): return TensorType(tvm.convert(shape), dtype) -def func_type(args, ret_type, type_params=[], type_constraints=[]): + +def func_type(args, ret_type, type_params=None, type_constraints=None): + if not type_params: + type_params = [] + if not type_constraints: + type_constraints = [] return FuncType(args, ret_type, type_params, type_constraints) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 8b49710f70ec..b075704c212a 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -1,7 +1,6 @@ -# pylint: disable=no-else-return +# pylint: disable=no-else-return, # pylint: disable=unidiomatic-typecheck -""" -The optimizer for Relay. +"""The optimizer for Relay. Exposes an interface for configuring the optimizer and scripting it directly in Python. @@ -26,6 +25,7 @@ from . import _ir_pass # Expose checking expression, should rename to infer_type. +# pylint: disable=invalid-name check_expr = _ir_pass.check_expr # # pylint: disable=invalid-name @@ -47,7 +47,10 @@ def mangle(name: str, types: List[Type]) -> str: name += str(typ) + "_" return name + T = TypeVar('T') + + class AbstractExprVisitor(Generic[T]): """A functional visitor over Expr in Python.""" @@ -104,11 +107,13 @@ def visit_global_var(self, _: GlobalVar) -> T: def to_pass(cls) -> Callable[[Environment], Callable[[GlobalVar, Function], Function]]: def _outer_wrapper(env): visitor = cls(env) - def _inner_wrapper(var, func): + + def _inner_wrapper(_, func): return visitor.visit(func) return _inner_wrapper return _outer_wrapper + class ExprVisitor(AbstractExprVisitor[Expr]): """A functional visitor over Expr in Python.""" @@ -149,8 +154,10 @@ def visit_tuple(self, tup: Tuple) -> Expr: def visit_constant(self, const: Constant) -> Expr: return const + MMCacheKey = Tuple[Union[GlobalVar, str], List[Type]] + class Monomorphize(ExprVisitor): """A monomorphization pass. @@ -182,11 +189,12 @@ def visit_call(self, call: Call) -> Expr: mono_name = mangle(poly_name, call.type_args) for arg in call.type_args: if isinstance(arg, TypeParam): - return call # raise Exception("...") # Fix me in the morning!!! + # raise Exception("...") # Fix me in the morning!!! + return call mono_op = specialize_op(poly_name, mono_name, call.type_args) self.monomorph_map[cache_key] = mono_op - return Call(mono_op, new_args,call.attrs, []) + return Call(mono_op, new_args, call.attrs, []) elif isinstance(call.op, GlobalVar): return call # defn = self.env.lookup(call.op) @@ -203,7 +211,7 @@ def visit_call(self, call: Call) -> Expr: # self.env.add(defn) # self.visit_item(defn) # return Call(new_id, call.args, call.attrs) - + elif isinstance(call.op, Function): return call # new_func = type_specialize(call.type_args, call.op) @@ -222,4 +230,3 @@ def visit_call(self, call: Call) -> Expr: # TODO(@jroesch): Fix up my type __tgt_host__ = __tgt__ = "llvm" __relay_tvm_context__ = tvm.cpu() - diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 47ebc5501cab..5c3a8ac249a6 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -1,3 +1,4 @@ +#pylint: disable=wildcard-import """Relay core operators.""" # operator defs from .op import get, register, Op, compile_ops diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index da94ec89b380..4427faa6a3a6 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -1,17 +1,20 @@ +#pylint: disable=invalid-name """Backend compiler related feature regsitration""" +from topi import add from .op import register from ..type import FuncType, TensorType from ...schedule import create_schedule from ...api import placeholder -from topi import add def type_to_placeholder(name, ty): + """Convert a single type into the correct placeholder.""" if isinstance(ty, TensorType): return placeholder(ty.shape, name=name, dtype=ty.dtype) else: raise Exception("can only pass Tensor values to TVM operators") def func_ty_to_placeholders(func_ty): + """Build input placeholders based on a function type.""" if isinstance(func_ty, FuncType): arg_types = func_ty.arg_types ret_type = func_ty.ret_type @@ -45,12 +48,13 @@ def func_ty_to_placeholders(func_ty): # schedule = tvm.create_schedule(Output.op) # return [schedule, Inputs + [Output]] - -def add_compiler(op_name, func_type, *args): - Inputs, ret_ty = func_ty_to_placeholders(func_type) +#pylint: disable=duplicate-argument-name +def add_compiler(_, func_type, *_): + """The compilation code for the TVM compiler.""" + inputs, _ = func_ty_to_placeholders(func_type) # op = lookup_in_topi(op_name) - Output = add(*Inputs) - schedule = create_schedule(Output.op) - return [schedule, Inputs + [Output]] + output = add(*inputs) + schedule = create_schedule(output.op) + return [schedule, inputs + [output]] -register("add", "FRelayOpCompiler", add_compiler) \ No newline at end of file +register("add", "FRelayOpCompiler", add_compiler) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index bb589f40f138..14570b62269b 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -3,13 +3,13 @@ from ..base import register_relay_node from ..expr import Expr -from ..._ffi.function import Function, register_func -from ...api import convert -from ...container import Map -from ... import lower, build, cpu +from ..._ffi.function import register_func +from ... import lower, build + @register_relay_node class Op(Expr): + """A Relay operator definition.""" def __init__(self): raise RuntimeError("Cannot create op, use get instead") @@ -74,6 +74,7 @@ def _register(v): return v return _register(value) if value else _register + def compile_ops(op_names): """Register an operator property of an operator. @@ -90,6 +91,8 @@ def compile_ops(op_names): return _CompileOpsToModule(*op_names) # TODO(@jroesch): We should port to C++, just need to figure out how to write this code. + + @register_func("relay.op._compile_ops") def _compile_ops(op_impls): lowered = [] @@ -100,8 +103,10 @@ def _compile_ops(op_impls): # TOOD(@jroesch): Where should we read these settings from return build(lowered, target='llvm', target_host='llvm') + _init_api("relay.op", __name__) + def specialize_op(op_name, new_op_name, type_args): """Specializes an operator to a set of types and assigns it new_op_name. @@ -110,7 +115,7 @@ def specialize_op(op_name, new_op_name, type_args): add : forall (T : Type) (U : Type), (U, T) -> Broadcast(U, T) - This is a function which is polymorphic over two types `T` and `U` and + This is a function which is polymorphic over two types `T` and `U` and takes a value of type `T` and one of `U` and returns `Broadcast` of U and T. @@ -135,9 +140,9 @@ def specialize_op(op_name, new_op_name, type_args): ---------- op_name : str The operator to be specialized. - + Returns ------- The specialized operator. """ - return _SpecializeOp(op_name, new_op_name, type_args) \ No newline at end of file + return _SpecializeOp(op_name, new_op_name, type_args) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index d0c1b88eb240..57fbccf488dc 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -59,6 +59,7 @@ def sqrt(data): """ return _make.sqrt(data) + def add(lhs, rhs): """Take sqrt of data. @@ -76,6 +77,7 @@ def add(lhs, rhs): """ return _make.add(lhs, rhs) + def subtract(lhs, rhs): """Take sqrt of data. @@ -93,5 +95,6 @@ def subtract(lhs, rhs): """ return _make.add(lhs, rhs) + def equal(lhs, rhs): - return _make.equal(lhs, rhs) \ No newline at end of file + return _make.equal(lhs, rhs) diff --git a/python/tvm/relay/to_tvm.py b/python/tvm/relay/to_tvm.py index 181251844a6d..615a39301142 100644 --- a/python/tvm/relay/to_tvm.py +++ b/python/tvm/relay/to_tvm.py @@ -155,6 +155,7 @@ def visit_local_var(self, ident: LocalVar) -> NodeRef: return self.lookup(ident) def visit_call(self, call: Call) -> NodeRef: + """Transform a ::tvm.relay.Call into an operator in the TVM graph.""" inputs = [] for arg in call.args: inputs.append(self.visit(arg).to_json()) @@ -222,7 +223,7 @@ def to_json(self) -> str: return json.dumps(json_dict) -def compile(func): +def compile_to_tvm(func): """Compile a single function to the components needed by the TVM RTS. """ diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index d9fc1eff1fd0..22c853ef512f 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -1,9 +1,9 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The type nodes of the Relay language.""" -from typing import Tuple, List +from typing import List from enum import IntEnum -from .base import Span, NodeBase, register_relay_node from tvm import expr +from .base import Span, NodeBase, register_relay_node from . import _make @@ -67,18 +67,18 @@ class Kind(IntEnum): @register_relay_node class TypeParam(Type): """A type parameter used for generic types in Relay, - see tvm/relay/type.h for more details. + see tvm/relay/type.h for more details. - A type parameter represents a type placeholder which will - be filled in later on. This allows the user to write - functions which are generic over types. + A type parameter represents a type placeholder which will + be filled in later on. This allows the user to write + functions which are generic over types. """ var: expr.Var kind: Kind span: Span def __init__(self, var: expr.Var, kind: Kind) -> None: - """Construct a TypeParam. + """Construct a TypeParam. Parameters ---------- @@ -87,7 +87,7 @@ def __init__(self, var: expr.Var, kind: Kind) -> None: kind: Kind The kind of the type parameter. - + Returns ------- type_param: TypeParam @@ -112,8 +112,7 @@ class FuncType(Type): being, a sequence of argument types, and a return type. We informally write them as: - `forall (type_params), (arg_types) -> ret_type - where type_constraints` + `forall (type_params), (arg_types) -> ret_type where type_constraints` """ type_params: List[TypeParam] type_constraints: List[TypeConstraint] @@ -121,8 +120,12 @@ class FuncType(Type): ret_type: Type span: Span - def __init__(self, arg_types: List[Type], ret_type: Type, type_params: List[TypeParam], type_constraints: List[TypeConstraint]) -> None: - """Construct a function type. + def __init__(self, + arg_types: List[Type], + ret_type: Type, + type_params: List[TypeParam], + type_constraints: List[TypeConstraint]) -> None: + """Construct a function type. Parameters ---------- @@ -130,7 +133,7 @@ def __init__(self, arg_types: List[Type], ret_type: Type, type_params: List[Type ret_type: Type type_params: list of TypeParam type_constraints: list of TypeConstraint - + Returns ------- func_type: FuncType @@ -142,8 +145,9 @@ def __init__(self, arg_types: List[Type], ret_type: Type, type_params: List[Type @register_relay_node class TypeCall(Type): - def __init__() -> None: - pass + def __init__(self, type_rel, args) -> None: + self.__init_handle_by_constructor__( + _make.TypeCall, type_rel, args) @register_relay_node diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index f169ff1b64ac..f0d60f514a37 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -6,8 +6,10 @@ from . import make as _make from . import expr as _expr + class TensorSlice(NodeGeneric, _expr.ExprOp): """Auxiliary data structure for enable slicing syntax from tensor.""" + def __init__(self, tensor, indices): if not isinstance(indices, tuple): indices = (indices,) @@ -31,9 +33,11 @@ def dtype(self): itervar_cls = None + @register_node class Tensor(NodeBase, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" + def __call__(self, *indices): ndim = self.ndim if len(indices) != ndim: @@ -104,6 +108,7 @@ def name(self): class Operation(NodeBase): """Represent an operation that generate a tensor""" + def output(self, index): """Get the index-th output of the operation diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index cd87fb83ec52..e225b0c5579a 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -33,7 +33,7 @@ def run(env, expr, inputs, shape): env.add("main", expr) env.transform(Monomorphize.to_pass()) main = env.lookup("main") - graph, lib, _ = to_tvm.compile(main) + graph, lib, _ = to_tvm.compile_to_tvm(main) # We use NNVM to load the graph right now because it populates node_row_ptr field. nnvm_graph = nnvm.graph.load_json(graph) module = graph_runtime.create(nnvm_graph, lib, tvm.cpu(0)) From 6ef69654368ec81b7bd327eb93fa48be9a8400f7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Sep 2018 15:18:51 -0700 Subject: [PATCH 79/88] Fix doc error --- include/tvm/relay/op.h | 8 ++++++-- include/tvm/relay/pass/alpha_eq.h | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 2d5627f2c844..f79728918086 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -149,7 +149,9 @@ class OpRegistry { const std::string& description); /*! * \brief Attach the type function corresponding to the return type. - * \param ty_func The type function to register for the return type. + * \param ty_func_name The type function name to register for the return type. + * \param type_fn The backing relation which can solve an arbitrary relation + * on variables. * \return reference to self. */ inline OpRegistry& add_type_func(const std::string& type_func_name, @@ -157,7 +159,9 @@ class OpRegistry { /*! * \brief Attach the type function corresponding to the return type. - * \param ty_func The type function to register for the return type. + * \param ty_func_name The type function name to register for the return type. + * \param type_fn The backing relation which can solve an arbitrary relation + * on variables. * \return reference to self. */ inline OpRegistry& add_type_func( diff --git a/include/tvm/relay/pass/alpha_eq.h b/include/tvm/relay/pass/alpha_eq.h index 87b5164462d7..b6d98bd68940 100644 --- a/include/tvm/relay/pass/alpha_eq.h +++ b/include/tvm/relay/pass/alpha_eq.h @@ -1,6 +1,6 @@ /*! * Copyright (c) 2018 by Contributors - * \file tvm/relay/alpha_eq.h + * \file tvm/relay/pass/alpha_eq.h * \brief Check expressions and types for structural equivalence. */ #ifndef TVM_RELAY_PASS_ALPHA_EQ_H_ From a07c956672ea78b0bd0dd6fe0e78777d649416f2 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Sep 2018 15:22:48 -0700 Subject: [PATCH 80/88] Fix doc error again --- include/tvm/relay/op.h | 24 ++++++++++++------------ src/relay/op/tensor/elemwise.cc | 12 ++++++------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index f79728918086..7d0a58265565 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -149,24 +149,24 @@ class OpRegistry { const std::string& description); /*! * \brief Attach the type function corresponding to the return type. - * \param ty_func_name The type function name to register for the return type. - * \param type_fn The backing relation which can solve an arbitrary relation + * \param type_rel_name The type function name to register for the return type. + * \param type_rel The backing relation which can solve an arbitrary relation * on variables. * \return reference to self. */ - inline OpRegistry& add_type_func(const std::string& type_func_name, - TypeRelationFn type_fn); + inline OpRegistry& add_type_rel(const std::string& type_rel_name, + TypeRelationFn type_rel); /*! * \brief Attach the type function corresponding to the return type. - * \param ty_func_name The type function name to register for the return type. - * \param type_fn The backing relation which can solve an arbitrary relation + * \param type_rel_name The type function name to register for the return type. + * \param type_rel The backing relation which can solve an arbitrary relation * on variables. * \return reference to self. */ - inline OpRegistry& add_type_func( - const std::string& type_func_name, - std::function(const Array&, int)> type_fn); + inline OpRegistry& add_type_rel( + const std::string& type_rel_name, + std::function(const Array&, int)> type_rel); /*! * \brief Set the type key of attributes. @@ -356,15 +356,15 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name, return *this; } -inline OpRegistry& OpRegistry::add_type_func( +inline OpRegistry& OpRegistry::add_type_rel( const std::string& type_func_name, std::function(const Array&, int)> type_fn) { auto pfunc = runtime::TypedPackedFunc(const Array&, int)>(type_fn); - return add_type_func(type_func_name, pfunc); + return add_type_rel(type_func_name, pfunc); } -inline OpRegistry& OpRegistry::add_type_func(const std::string& type_func_name, +inline OpRegistry& OpRegistry::add_type_rel(const std::string& type_func_name, TypeRelationFn type_fn) { auto type_func = TypeRelationNode::make(type_func_name, 0, type_fn); diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index d6a04773b7fa..a18259c72117 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -37,7 +37,7 @@ RELAY_REGISTER_UNARY_OP("log") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_func("Log", IdentityRel); +.add_type_rel("Log", IdentityRel); // data : Tensor[shape, dtype] // result: Tensor[shape, dtype] @@ -51,7 +51,7 @@ RELAY_REGISTER_UNARY_OP("exp") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_func("Exp", IdentityRel); +.add_type_rel("Exp", IdentityRel); RELAY_REGISTER_UNARY_OP("sqrt") @@ -62,7 +62,7 @@ RELAY_REGISTER_UNARY_OP("sqrt") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_func("Sqrt", IdentityRel); +.add_type_rel("Sqrt", IdentityRel); // Addition TVM_REGISTER_API("relay.op._make.add") @@ -76,7 +76,7 @@ RELAY_REGISTER_OP("add") .add_argument("lhs", "Tensor", "The left hand side tensor.") .add_argument("rhs", "Tensor", "The right hand side tensor.") .set_support_level(1) - .add_type_func("Broadcast", BroadcastRel); + .add_type_rel("Broadcast", BroadcastRel); // def broadcast(s1, s2): // ... @@ -97,7 +97,7 @@ RELAY_REGISTER_OP("subtract") .add_argument("lhs", "Tensor", "The left hand side tensor.") .add_argument("rhs", "Tensor", "The right hand side tensor.") .set_support_level(1) - .add_type_func("BroadcastComp", BroadcastCompRel); + .add_type_rel("BroadcastComp", BroadcastCompRel); // def broadcast(s1, s2): // ... @@ -118,7 +118,7 @@ RELAY_REGISTER_OP("equal") .add_argument("lhs", "Tensor", "The left hand side tensor.") .add_argument("rhs", "Tensor", "The right hand side tensor.") .set_support_level(1) - .add_type_func("BroadcastComp", BroadcastCompRel); + .add_type_rel("BroadcastComp", BroadcastCompRel); } // namespace relay } // namespace tvm From f6e51c648050ab95b730d8b67fac08e61bd9f444 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Sep 2018 15:29:09 -0700 Subject: [PATCH 81/88] Fix signed/unsigned compare --- src/relay/op/type_relations.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 2a6efbcf71e4..fb9008b3e8f2 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -91,7 +91,7 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, CHECK_EQ(larger.size(), smaller.size()); Array out_shape; - for (int i = 0; i < smaller.size(); i++) { + for (size_t i = 0; i < smaller.size(); i++) { auto left = smaller[i].as(); auto right = larger[i].as(); CHECK(left); From 1e10ba2f6ab75ae89c277c60d2dcbcb05ffa53b2 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Sep 2018 15:33:07 -0700 Subject: [PATCH 82/88] Kill a few more warnings --- src/relay/pass/type_infer.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index df896fa3936a..4873b0a55580 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -86,13 +86,14 @@ struct TypeNormalizer : TypeFVisitor { CHECK(new_args.size() == normalized_args.size()); tvm::Array final_args; - for (int i = 0; i < new_args.size(); i++) { + for (size_t i = 0; i < new_args.size(); i++) { final_args.push_back(unifier->unify(normalized_args[i], new_args[i])); } return TypeCallNode::make(ty_call->func, final_args); } else { - CHECK(false); + throw InternalError("found non type relation in the "\ + "type call function position"); } } } From ed1a84f80f306d8e1a6c73e72f0026e8d947a5c5 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Sep 2018 15:48:44 -0700 Subject: [PATCH 83/88] Remove another size_t --- src/relay/ir/op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 18a647798c9e..769f26a42101 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -201,7 +201,7 @@ Op SpecializeOp(const std::string& op_name, const std::string& new_op_name, // Build a subsitituion map up from the function type and type arguments. // Eventually allow the type vars to be passed in. - for (auto i = 0; i < type_args.size(); i++) { + for (size_t i = 0; i < type_args.size(); i++) { subst_map.Set(fn_ty->type_params[i], type_args[i]); } From fe9c4088c0ed39364e0ba4a298cef4864d9c7075 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Sep 2018 15:52:14 -0700 Subject: [PATCH 84/88] Fix warning --- src/relay/pass/unifier.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 7735ca8b0482..f1411bf9476c 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -184,7 +184,7 @@ Type TypeUnifierNode::VisitType(const Type &t1, const Type t2) { // // We flip the arguments so we hit the TypeCall and other case in there is // ever a type call. - } else if (const TypeCallNode *tvn2 = t2.as()) { + } else if (t2.as()) { return TypeFunctor::VisitType(t2, t1); } else { return TypeFunctor::VisitType(t1, t2); From 8a85b4fb15f968223c1a3e219717827908bdbb59 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 8 Sep 2018 18:25:18 -0700 Subject: [PATCH 85/88] Rewriting language reference to newest version of Relay. --- docs/api/python/index.rst | 1 - docs/api/python/relay/base.rst | 9 - docs/api/python/relay/env.rst | 6 - docs/api/python/relay/expr.rst | 36 ---- docs/api/python/relay/index.rst | 20 -- docs/api/python/relay/ir_builder.rst | 6 - docs/api/python/relay/ir_pass.rst | 3 - docs/api/python/relay/op.rst | 3 - docs/api/python/relay/to_tvm.rst | 3 - docs/api/python/relay/type.rst | 27 --- docs/conf.py | 2 +- docs/langref/relay/expressions.rst | 178 ------------------ docs/langref/relay/index.rst | 17 -- docs/langref/relay/intro.rst | 17 -- docs/langref/relay/type_system.rst | 137 -------------- include/tvm/relay/expr_visitor.h | 4 + python/tvm/relay/__init__.py | 5 +- python/tvm/relay/env.py | 32 +++- python/tvm/relay/expr.py | 2 - python/tvm/relay/ir_builder.py | 37 +++- python/tvm/relay/op/_tensor.py | 2 +- src/relay/ir/op.cc | 3 - src/relay/op/type_relations.cc | 24 +-- src/relay/pass/resolve.cc | 1 - src/relay/pass/unifier.cc | 2 +- .../relay/test_tyck_eval_integration.py | 8 +- 26 files changed, 86 insertions(+), 499 deletions(-) delete mode 100644 docs/api/python/relay/base.rst delete mode 100644 docs/api/python/relay/env.rst delete mode 100644 docs/api/python/relay/expr.rst delete mode 100644 docs/api/python/relay/index.rst delete mode 100644 docs/api/python/relay/ir_builder.rst delete mode 100644 docs/api/python/relay/ir_pass.rst delete mode 100644 docs/api/python/relay/op.rst delete mode 100644 docs/api/python/relay/to_tvm.rst delete mode 100644 docs/api/python/relay/type.rst delete mode 100644 docs/langref/relay/expressions.rst delete mode 100644 docs/langref/relay/index.rst delete mode 100644 docs/langref/relay/intro.rst delete mode 100644 docs/langref/relay/type_system.rst diff --git a/docs/api/python/index.rst b/docs/api/python/index.rst index ab411d77f4f4..59bd1795b7ec 100644 --- a/docs/api/python/index.rst +++ b/docs/api/python/index.rst @@ -23,5 +23,4 @@ Python API topi vta/index nnvm/index - relay/index hybrid diff --git a/docs/api/python/relay/base.rst b/docs/api/python/relay/base.rst deleted file mode 100644 index f0cec295ee6b..000000000000 --- a/docs/api/python/relay/base.rst +++ /dev/null @@ -1,9 +0,0 @@ -tvm.relay.base ------------ -.. automodule:: tvm.relay.base - -.. autoclass:: tvm.relay.base.NodeBase - :members: - -.. autoclass:: tvm.relay.base.Span - :members: \ No newline at end of file diff --git a/docs/api/python/relay/env.rst b/docs/api/python/relay/env.rst deleted file mode 100644 index eca7312d5bbb..000000000000 --- a/docs/api/python/relay/env.rst +++ /dev/null @@ -1,6 +0,0 @@ -tvm.relay.env ------------ -.. automodule:: tvm.relay.env - -.. autoclass:: tvm.relay.env.Environment - :members: \ No newline at end of file diff --git a/docs/api/python/relay/expr.rst b/docs/api/python/relay/expr.rst deleted file mode 100644 index cd0cb5c308c4..000000000000 --- a/docs/api/python/relay/expr.rst +++ /dev/null @@ -1,36 +0,0 @@ -tvm.relay.expr ------------ -.. automodule:: tvm.relay.expr - -.. autoclass:: tvm.relay.expr.ExprBuilder - :members: - -.. autoclass:: tvm.relay.expr.Expr - :members: - -.. autoclass:: tvm.relay.expr.Constant - :members: - -.. autoclass:: tvm.relay.expr.Tuple - :members: - -.. autoclass:: tvm.relay.expr.LocalVar - :members: - -.. autoclass:: tvm.relay.expr.GlobalVar - :members: - -.. autoclass:: tvm.relay.expr.Param - :members: - -.. autoclass:: tvm.relay.expr.Function - :members: - -.. autoclass:: tvm.relay.expr.Call - :members: - -.. autoclass:: tvm.relay.expr.Let - :members: - -.. autoclass:: tvm.relay.expr.If - :members: \ No newline at end of file diff --git a/docs/api/python/relay/index.rst b/docs/api/python/relay/index.rst deleted file mode 100644 index 231d49df0e6d..000000000000 --- a/docs/api/python/relay/index.rst +++ /dev/null @@ -1,20 +0,0 @@ -Relay API -========= - -This document contains the Python API to the Relay frontend, optimizer, and -compiler toolchain. - -Relay is the second generation high level intermediate representation for the TVM -compiler stack. - -.. toctree:: - :maxdepth: 2 - - base - env - expr - ir_builder - ir_pass - op - to_tvm - type diff --git a/docs/api/python/relay/ir_builder.rst b/docs/api/python/relay/ir_builder.rst deleted file mode 100644 index b12e3cc6cdd1..000000000000 --- a/docs/api/python/relay/ir_builder.rst +++ /dev/null @@ -1,6 +0,0 @@ -tvm.relay.ir_builder ------------ -.. automodule:: tvm.relay.ir_builder - -.. autoclass:: tvm.relay.ir_builder.IRBuilder - :members: \ No newline at end of file diff --git a/docs/api/python/relay/ir_pass.rst b/docs/api/python/relay/ir_pass.rst deleted file mode 100644 index e2e3b432e5bd..000000000000 --- a/docs/api/python/relay/ir_pass.rst +++ /dev/null @@ -1,3 +0,0 @@ -tvm.relay.ir_pass ------------ -.. automodule:: tvm.relay.ir_pass \ No newline at end of file diff --git a/docs/api/python/relay/op.rst b/docs/api/python/relay/op.rst deleted file mode 100644 index fb8e9ce774c2..000000000000 --- a/docs/api/python/relay/op.rst +++ /dev/null @@ -1,3 +0,0 @@ -tvm.relay.op ------------ -.. automodule:: tvm.relay.op \ No newline at end of file diff --git a/docs/api/python/relay/to_tvm.rst b/docs/api/python/relay/to_tvm.rst deleted file mode 100644 index 72d01d123e0f..000000000000 --- a/docs/api/python/relay/to_tvm.rst +++ /dev/null @@ -1,3 +0,0 @@ -tvm.relay.to_tvm ------------ -.. automodule:: tvm.relay.to_tvm diff --git a/docs/api/python/relay/type.rst b/docs/api/python/relay/type.rst deleted file mode 100644 index d357df8f08ac..000000000000 --- a/docs/api/python/relay/type.rst +++ /dev/null @@ -1,27 +0,0 @@ -tvm.relay.type ------------ -.. automodule:: tvm.relay.type - -.. autoclass:: tvm.relay.type.Type - :members: - -.. autoclass:: tvm.relay.type.TensorType - :members: - -.. autoclass:: tvm.relay.type.Kind - :members: - -.. autoclass:: tvm.relay.type.TypeParam - :members: - -.. autoclass:: tvm.relay.type.TypeConstraint - :members: - -.. autoclass:: tvm.relay.type.FuncType - :members: - -.. autoclass:: tvm.relay.type.TypeCall - :members: - -.. autoclass:: tvm.relay.type.IncompleteType - :members: \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index e3f7f6a82c24..717003824703 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -33,7 +33,7 @@ # General information about the project. project = u'tvm' author = u'%s developers' % project -copyright = u'2017, %s' % author +copyright = u'2018, %s' % author github_doc_root = 'https://github.com/tqchen/tvm/tree/master/docs/' # add markdown parser diff --git a/docs/langref/relay/expressions.rst b/docs/langref/relay/expressions.rst deleted file mode 100644 index 37dc62c6bc24..000000000000 --- a/docs/langref/relay/expressions.rst +++ /dev/null @@ -1,178 +0,0 @@ -================== -Expressions -================== - -Relay's IR is a pure expression oriented language, that has a -dataflow fragment and structured control flow. Although Relay's -representation is a tree, it is possible to view the dataflow -fragments as graph for purposes of writing and expressing -transformations. - -The below sections make an attempt to clearly split the dataflow -fragment from the control fragment. - -================== -Dataflow Expressions -================== - -First we will cover the set of nodes which do not involve control flow, -this fragment of the language is semantically equivalent to pure -computation graphs without control flow. - -Constants -~~~~~~~~~ -Relay programs can contain constant Tensor values, since in Relay all -values are either Tensors, Products, or Closures. We will discuss the -later two later, but we represent Tensor constants as `tvm.NDArray`, -allowing us to utilize normal operators for constant evaluation. - - -Constructors -~~~~~~~~ - -Relay supports a handful of constructors which we will cover below. A -constructor enables programs to build new values from arbitrary Relay -expressions. - - -We support four types of literals, literals are type polymorphic and can -assigned any base type. If we can not solve for a concrete type we apply -a defaulting rule. - -We support signed and unsigned integers, floating point numbers, booleans, -and tensor literals. - -The base type literals are designed to closely model literals in TVM's -expressions langauge. - -### Boolean Literals -TODO: don't have these in any form right now - -### Integer Literals -TODO: don't have these in any form right now - -Tensor Constructor -~~~~~~~~~~~~~~~ - -A tensor literal allows us to build a Tensor from other expressions. - -TODO: Example here - - -Tuple Constructor -~~~~~~~~~~~~~~~ - -We support tuple constructors which allows us to build a fixed-k sized -sequence of heterogenous data. These tuples match closely to Python's -and enable efficient projection of their members due to their fixed length. - - (a, b, c) : Tuple - - (a + b + c, d) : Tuple, Tensor> - -Function -~~~~~~~~ - -A function node represents a function, it contains a seqeuence of -parameters, a return type, and a body. - - fun (x : Float, y: Float) -> Float { x + y } - -Functions are first class in Relay, and can be used in any expression -position. Functions are the same as global functions, but do not have -an explicit name. You can use a function in conjunction with a let -binding to define locally recursive functions. - - let fact = fun (x : Float) -> Float { - if (x == 0) { - 0 - } else { - x * fact(x - 1) - }; - fact(10) - -Identifiers -~~~~~~~~~~~ - -All of the identifiers are valid expressions, you can use a local identifier, -global identifier, or intrinsic identifier anywhere an expression may appear. - -For example the below fragment of code is a valid expression. - - %ret = @global(intrinsic, %local) - -Let Binding -~~~~~~~~~~~ - -An immutable variable binding, allows the user to bind an -expression to a name. A let binding contains a local identifier, -an optional type, a value, and body expression which may -reference the bound identifier. - -We will first introduce a single binding with no type -anntoations:: - let %x = %a + %b; - x - -The value of a let binding is the value of the final expression -after evaluating the bindings it depends on. - -A user can write a sequence of let bindings, we can view -these blocks and pure dataflow -single binding. These blocks are pure dataflow, and can -be evaluated in any order, reordered up to dataflow. - -We support a sequence of bindings followed by a body which -is the continutation after executing the sequence of bindings. - -I believe this representation will be easier to manipulate then -the mixed dataflow/control flow comptuation graphs. -Data flow and control flow are strictly seperated in this representation -and we can easily syntactically discriminate. When in ANF there should only be -general control flow between `Assignment` nodes and not within the values bound -in bindings. - -This representation also makes it easy to apply reverse more since -sequences of assignments where the only control flow is call instructions -are treated by the algorithm uniformly, and each control flow construct -must be handled individualy. - -TODO Add Ref, ReadRef, WriteRef, Projection, - -Gradient -~~~~~~~~ - -The `Reverse` acts as a marker node, when the compiler encounters it -we will apply the reverse mode transformation to the enclosed function. - -We will employ static analysis and constant evaluation in order to -simplify the node's argument to a known function call target. - - -You can compute the reverse node of a function node like so: - -Cast -~~~~~ - -Cast the type of the `node` to `ty`. - -======================= -Control Flow Expression -======================= -Control flow expressions change network topology based on values -computed by previous expressions. - -Call -~~~~ - -Terms with function types in Relay are "callable", that can be invoked like -a function in a typical programming language by supplying a set of arguments. - -Instrinsics with functions types, definitions, and functions are all callable. - -If-Then-Else -~~~~~~~~~~~~ - -Relay has a simple if/then/else expression which allows programs to branch -on a single control value which must be of type `Bool`, i.e a zero-rank -tensor of booleans. diff --git a/docs/langref/relay/index.rst b/docs/langref/relay/index.rst deleted file mode 100644 index 617e745acdfc..000000000000 --- a/docs/langref/relay/index.rst +++ /dev/null @@ -1,17 +0,0 @@ -Relay Language Reference -======================== - -This document is a work in progress language reference describing -Relay, TVM's high level intermediate representation. The name is an -allusion to interneurons which are often referred to as intermediate, -or relay neurons. - -We will continually iterate on this document as we evolve the new IR -and update accordingly. - -.. toctree:: - :maxdepth: 2 - - intro - expressions - type_system diff --git a/docs/langref/relay/intro.rst b/docs/langref/relay/intro.rst deleted file mode 100644 index 617e745acdfc..000000000000 --- a/docs/langref/relay/intro.rst +++ /dev/null @@ -1,17 +0,0 @@ -Relay Language Reference -======================== - -This document is a work in progress language reference describing -Relay, TVM's high level intermediate representation. The name is an -allusion to interneurons which are often referred to as intermediate, -or relay neurons. - -We will continually iterate on this document as we evolve the new IR -and update accordingly. - -.. toctree:: - :maxdepth: 2 - - intro - expressions - type_system diff --git a/docs/langref/relay/type_system.rst b/docs/langref/relay/type_system.rst deleted file mode 100644 index 91a634431d7c..000000000000 --- a/docs/langref/relay/type_system.rst +++ /dev/null @@ -1,137 +0,0 @@ -================== -Type System -================== - -We have briefly introduced types while detailing the the expression language -of Relay, but have fully laid out the type system. - -Although the majority of Relay programs require no type annotations, Relay -is statically typed. Each expression in Relay has a precisely known type. - -You might ask why we want a statically typed IR, there are multiple advantages. -- efficient layout and code generation for tensors -- TODO -- debugging transformations (most program transformations should be type perserving) - -We are able to omit these type annotations by a process known as type inference. -Type inference is a technique that has its roots in the programming language -community, and can be viewed as a method for generalizing shape inference to -run over arbitrary user programs. - -Static typing means we know before executing the program properties about -the values it manipulates. Static types are useful for compiler optimization -because they communicate properties about the data we manipulate, such as -runtime shape, data layout, storage. - -Most current IRs use "shape inference" to recover Tensor dimensions from the user -provided program. Machine learning users have enjoyed shape inference for -tensors because it allows them to generate performant code without giving up -on the expressivity of the input language. - -Because Relay is intended as an IR we require *some* type information to provide -full inference. We don't believe this to be an issue as many of the IR builder -inferfaces require some type information, or can generate IR based on their own -higher level inferences. - -We view this limited shape inference as a simpler form of type -inference. Instead of relying on an ad-hoc procedure for recovering type -information from a potentially dynamic program, we apply ideas from compiler and IR design. - -Below we briefly dicsuss the different kinds of types in Relay. - -===== -Types -===== - -BaseType -~~~~~~~~~~ -Relay has a notion of a BaseType, which captures the set of types -that can be stored in a Tensor. Relay's base types map to the set -of types supported by TVM. - -Each of the base types can be parametrized by number of bits, and -lanes for vectorization purposes. We support four base types any:`Bool`, -any:`Int` - -Type Variables -~~~~~~~~~~~~~~ - -Type Parameters -~~~~~~ -TODO: type parameter - -Kind -~~~~ - -Function Types -~~~~~~~~~~ -TODO: rename function type? - -TypeQuantifier -~~~~~~~~~~~~~~ -TODO - -Placeholders -~~~~~~~~~~~~ - -TODO - -Tuple Types -~~~~~~~~~~~~~ - -Reference Types -~~~~~~~~~~~~~~~ - -A reference type is simply a mutable memory location, since Relay is a pure -language by default we need a way to introduce limited mutability. In this -case mutable data is clearly marked in the type system as a reference type. - - Ref - -Tensor Type -~~~~~~~~~~~ - -Tensor values in Relay are typed with tensor types. A tensor type is -parametrized by a data type, and shape. The data type must be a base -type as enforced by the kind checking rules described in TODO. - -This restriction importantly means - -The shape may be any valid Relay shape as described in the below -section on shapes. - - -====== -Shapes -====== - -Shape Singleton -~~~~~~~~~~~~~~~ -I don't like this name - -ShapeAttr -~~~~~~~~~ -TODO - -ShapeProjection -~~~~~~~~~~~~~~~ -TODO - -ShapeBinaryOp -~~~~~~~~~~~~~ - -enum ShapeOp : int { - SHPLUS = 0, - SHSUB = 1, - SHMUL = 2, - SHDIV = 3 -}; - - -Shape Sequence -~~~~~~~~ -A sequence of shapes ... - - -ShapeBroadcast -~~~~~~~~~~~~~~ diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 6f2a7f98542a..0febad503b12 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -75,6 +75,10 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { return GetRef(op); } + Expr VisitExpr_(const ConstantNode* op) override { + return GetRef(op); + } + Expr VisitExpr_(const GlobalVarNode* op) override { return GetRef(op); } diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index aae019c8d9c1..c254c7e9ce7a 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -3,7 +3,10 @@ from . import base from . import type as tpe from . import expr - +from . import to_tvm +from . import env +from . import ir_pass +from . import ir_builder # Operators from .op import Op from .op.tensor import * diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index beef6fd1a62c..93cbe1bca284 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -1,5 +1,5 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import -"""A global environment storing everything needed to interpret or compile a Realy program.""" +"""A global environment storing everything needed to interpret or compile a Relay program.""" from .base import register_relay_node, NodeBase from . import _make from . import _env @@ -10,25 +10,55 @@ class Environment(NodeBase): options and more. """ def __init__(self, funcs) -> None: + """Construct an environment. + + Parameters + ------ + funcs: list of relay.Function + + Returns + ------ + env: A new environment containing :py:class:`~relay.env.Environment`. + """ self.__init_handle_by_constructor__(_make.Environment, funcs) def add(self, var, func) -> None: + """Add a function to the environment. + + Parameters + --------- + var: GlobalVar + The global variable which names the function. + + func: Function + The function. + """ if isinstance(var, str): var = _env.Environment_GetGlobalVar(self, var) _env.Environment_Add(self, var, func) def merge(self, other): + """Merge two environments. + + Parameters + ---------- + other: Environment + The environment to merge into the current Environment. + """ return _env.Environment_Merge(self, other) def global_var(self, var): + """Get a global variable by name.""" return _env.Environment_GetGlobalVar(self, var) def lookup(self, var): + """Lookup a global function by name or by variable.""" if isinstance(var, str): return _env.Environment_Lookup_str(self, var) else: return _env.Environment_Lookup(self, var) def transform(self, transformer): + """Apply a transformer function to the environment.""" _env.Environment_Transform(self, transformer) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 748b2aa1e282..3cdaed89d2fb 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -15,8 +15,6 @@ class ExprBuilder(): def __call__(self, *args, **kwargs): converted_args = [] for arg in args: - import pdb - pdb.set_trace() if isinstance(arg, Param): converted_args.append(arg.var) else: diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index a271a537b290..c0c2e76c1157 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -79,7 +79,7 @@ def to_func(self): #pylint: disable=invalid-name def _mk_let(bindings, ret_value): let_expr = ret_value - for var, value, ty in reversed(list(bindings.items())): + for var, (value, ty) in reversed(list(bindings.items())): let_expr = Let(var, value, let_expr, ty) return let_expr @@ -114,7 +114,7 @@ def exit_scope(self): return bindings, scopes, params, ret_value #pylint: disable=invalid-name - def bind(self, name, ty, value): + def bind(self, name, value, ty): lv = LocalVar(name) self.scopes[-1][name] = lv self.bindings[-1][lv] = (value, ty) @@ -127,16 +127,35 @@ def let(self, name, value, value_type=None): if not isinstance(value, Expr): value = into_ast(value) - return self.bind(name, value_type, value) + return self.bind(name, value, value_type) + + def _convert_params(self, raw_params): + relay_params = [] + for raw_param in raw_params: + if isinstance(raw_param, Param): + var = raw_param.var + param = raw_param + elif isinstance(raw_param, tuple): + var, ty = raw_param + if isinstance(var, str): + var = LocalVar(var) + param = Param(var, ty) + elif isinstance(param, str): + var = LocalVar(raw_param) + ty = None + param = Param(var, ty) + else: + raise Exception("unknown parameter type") + + self.scopes[-1][var.name_hint] = var + relay_params.append(param) + + return relay_params def function(self, *params): """Construct a Relay function.""" - relay_params = [] - for param in params: - name = param.var - ty = param.type - self.scopes[-1][name.name_hint] = name - relay_params.append(Param(name, ty)) + + relay_params = self._convert_params(params) # self.params.append(relay_params) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 4427faa6a3a6..2a0ecc6c8550 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -49,7 +49,7 @@ def func_ty_to_placeholders(func_ty): # return [schedule, Inputs + [Output]] #pylint: disable=duplicate-argument-name -def add_compiler(_, func_type, *_): +def add_compiler(_, func_type, *__): """The compilation code for the TVM compiler.""" inputs, _ = func_ty_to_placeholders(func_type) # op = lookup_in_topi(op_name) diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 769f26a42101..7c005acb8648 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -155,9 +155,7 @@ Module CompileOpsToModule(const std::vector& op_names) { if (!IsGeneric(op->op_type)) { auto compiler = compiler_map[op]; - std::cout << "ABOVE CALL" << std::endl; tvm::Array pair = compiler(op->name, op->op_type); - std::cout << "BELOW CALL" << std::endl; // TODO(@jroesch): I can't pass strings across what should be the // interface here. tvm::Array triple = {LocalVarNode::make(op->name), pair[0], @@ -183,7 +181,6 @@ TVM_REGISTER_API("relay.op._CompileOpsToModule") for (auto i = 0; i < args.num_args; i++) { names.push_back(args[i]); } - std::cout << "Right here" << std::endl; *ret = CompileOpsToModule(names); }); diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index fb9008b3e8f2..e2b2cba1e0ef 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -22,9 +22,9 @@ TensorType as_ttype(const Type& t) { // TODO(@jroesch) what size value do we extract? int to_int(const tvm::Expr& e) { + CHECK(e.defined()); auto imm = e.as(); - CHECK(imm); - std::cout << "TYPE: " << imm << imm->type << std::endl; + CHECK(imm) << "TYPE: " << imm << imm->type << std::endl; return imm->value; } @@ -53,17 +53,17 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, auto suffix_len = static_cast(std::min(sh1.size(), sh2.size())); auto full_len = static_cast(std::max(sh1.size(), sh2.size())); - std::cout << "Length" << suffix_len << full_len - << (full_len - suffix_len - 1) << std::endl; - auto lower_bound = full_len - suffix_len - 1; + auto rev_sh1 = sh1.rbegin(); + auto rev_sh2 = sh2.rbegin(); - for (int64_t i = full_len - 1; i > lower_bound; i--) { - std::cout << "Index i=" << i << std::endl; - auto dim1 = to_int(sh1[i]); - auto dim2 = to_int(sh2[i]); - if (dim1 != dim2) { - CHECK(false); + while (rev_sh1 != sh1.rend() && rev_sh2 != sh2.rend()) { + auto dim1 = to_int(*rev_sh1); + auto dim2 = to_int(*rev_sh2); + if ((dim1 != dim2) && ((dim1 != 1) && (dim2 != 1))) { + CHECK(false) << "Dimension mistmatch " << "dim1: " << dim1 << " dim2: " << dim2 << std::endl; } + rev_sh1++; + rev_sh2++; } Array larger; @@ -106,9 +106,9 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, Array BroadcastRel(const Array& types, int num_args) { CHECK_EQ(types.size(), 3); + RELAY_LOG(INFO) << "In1: " << types[0] << "In2: " << types[1] << "Out: " << types[2] << std::endl; if (auto t1 = as_ttype(types[0])) { if (auto t2 = as_ttype(types[1])) { - std::cout << t1->dtype << t2->dtype << std::endl; CHECK_EQ(t1->dtype, t2->dtype); return {t1, t2, ConcreteBroadcast(t1, t2, t1->dtype)}; } diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index f513e36c9a30..bc63d939959e 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -53,7 +53,6 @@ struct ResolveTypeExpr : ExprFVisitor { // term, then resolve e's old type and write // it back into the new node. auto new_e = ExprFVisitor::VisitExpr(e); - std::cout << e << std::endl; CHECK(e->checked_type_.defined()); auto resolved_cty = VisitType(e->checked_type_); new_e->checked_type_ = resolved_cty; diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index f1411bf9476c..f5e337eb17f7 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -294,7 +294,7 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { // Type unified_shape = this->VisitType(tt1->shape, tt2->shape); return rt2; } catch (const UnificationError &err) { - std::cout << "Need to check constraint " << tt1->shape << " = " + CHECK(false) << "Need to check constraint " << tt1->shape << " = " << tt2->shape << std::endl; } diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index e225b0c5579a..f9a3d098a3e2 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -104,8 +104,8 @@ def test_add_broadcast_op(): ttype = tensor_type(5, 5, 5) expected_ty = func_type([ttype, ttype], ttype) assert has_type(func.to_func(), expected_ty) - x_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) - y_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) + x_data = tvm.nd.array(np.random.rand(10, 4).astype('float32')) + y_data = tvm.nd.array(np.random.rand(5, 10, 1).astype('float32')) result = run(env, prog, {'x': x_data, 'y': y_data}, (5, 10, 4)) np.testing.assert_allclose( x_data.asnumpy() + y_data.asnumpy(), result.asnumpy()) @@ -171,8 +171,8 @@ def f(n: i32, data: f32) -> f32 { # to execute this. if __name__ == "__main__": - # test_monomorphic_let() - # test_single_op() + test_monomorphic_let() + test_single_op() test_add_op() test_add_broadcast_op() # test_dual_op() From d626dfd9a126064ec7a48b763f5cd73a9496b0d3 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 13 Sep 2018 16:16:37 -0700 Subject: [PATCH 86/88] Revert "Port docs from previous Relay version" This reverts commit d47f637a14b4cbfcbd356cb55df78df3b5093b88. --- tutorials/relay/implement_fma_transform.py | 141 --------------------- 1 file changed, 141 deletions(-) delete mode 100644 tutorials/relay/implement_fma_transform.py diff --git a/tutorials/relay/implement_fma_transform.py b/tutorials/relay/implement_fma_transform.py deleted file mode 100644 index 8c04e70aa846..000000000000 --- a/tutorials/relay/implement_fma_transform.py +++ /dev/null @@ -1,141 +0,0 @@ -"""How to use Relay to implement a simple two-operator fusion pass. -================================== -**Author**: `Jared Roesch `_ - -In this tutorial, we will demonstrate how to write a fusion pass for -the Relay IR. We demonstrate many Relay features including defining a -new operator, a program transform, the NNVM compatibility layer, -and executing the original and transformed programs on the Relay -evaluator and TVM runtime system. -""" - -################################################################ -# Introduction -# ------------------------- -# -# In this tutorial, we will demonstrate how to write a fusion pass for -# the Relay IR. We demonstrate many Relay features including defining a -# new operator, a program transform, the NNVM compatibility layer, -# and executing the original and transformed programs on the Relay -# evaluator and TVM runtime system. - -from typing import Any, Dict - -import numpy as np -import tvm -import topi - -from relay import ir, make as mk -from relay.ir import OperatorId -from relay.opt import ItemVisitor, ExprVisitor -from relay.frontend.nnvm import Variable, symbol -from relay.frontend.nnvm import compiler -from relay.frontend.global_env import get_env -from relay.operators.register import func_ty_to_placeholders, register_op -from relay.eval import defn_to_pyfunc -from relay.tyck import check_expr - -class ExprAtVisitor(ExprVisitor): - """A demo visitor which adds a new traversal strategy.""" - expr_map: Dict[ir.LocalId, ir.Expr] - - def __init__(self): - self.expr_map = {} - - def expr_at(self,id: ir.LocalId) -> ir.Expr: - try: - return self.expr_map[id] - except KeyError: - return id - - def visit_let(self, let: ir.Let) -> ir.Expr: - self.expr_map[let.id] = let.value - return super().visit_let(let) - -# let x = 1 + 1; -# ... x will map to 1 + 1 - -class FuseTwo(ExprAtVisitor): - """Rewrite b(a(x, y), z) into ab(x, y, z). """ - def __init__(self, a: OperatorId, b: OperatorId, a_and_b: OperatorId) -> None: - self.a = a - self.b = b - self.a_and_b = a_and_b - super().__init__() - - def visit_call(self, call: ir.Call) -> ir.Expr: - func = call.fn - if func == self.b: - assert len(call.args) == 2 # An assumption of this fusion code. - arg0 = self.expr_at(call.args[0]) - arg1 = self.expr_at(call.args[1]) - if isinstance(arg0, ir.Call) and arg0.fn == self.a: - new_call = mk.Call(self.a_and_b, arg0.args[:] + [arg1]) - elif isinstance(arg1, ir.Call) and arg1.fn == self.a: - new_call = mk.Call(self.a_and_b, arg1.args[:] + [arg0]) - else: - new_call = super().visit_call(call) - - return new_call - else: - return super().visit_call(call) - -def fma_compile(op_name: str, func_ty: ir.Type, attrs: ir.Attributes=None) -> Any: - Inputs, ret_ty = func_ty_to_placeholders(func_ty) - x, y, z = Inputs - Output = topi.multiply(topi.add(x, y), z) - # this is not a python function call, but builds an AST - schedule = tvm.create_schedule(Output.op) - return [schedule, Inputs + [Output]] - - -def register_fma(env: Any) -> None: - """Register TOPI's elementwise broadcast addition for the `+` operator.""" - shape = mk.TypeParam("s", ir.Kind.Shape) - bt = mk.TypeParam("bt", ir.Kind.BaseType) - in_out_type = mk.TensorType(bt, shape) - fma_type = mk.TypeQuantifier(bt, mk.TypeQuantifier(shape, mk.TypeArrow([in_out_type, in_out_type, in_out_type], in_out_type))) - # forall (bt: BaseTYpe) (s : Shape), Tensor[bt, s] -> Tensor[bt, s] -> Tensor[bt, s] - # TODO: no reverse mode - register_op(env, 'fma', fma_type, compiler=fma_compile) - -# Get the global environment for demo purposes. -env = get_env() - -register_fma(env) - -# A small helper which applies just our transform to the Relay expression. -def transform(e): - fuse = FuseTwo(env.add_id(), env.mul_id(), env.operator_id('fma')) - e = fuse.visit(e) - # Now let's use the type checker to make sure we didn't make a mistake. - check_expr(env, e) - return e - -# We will use NNVM frontend. -x = Variable('x') -y = Variable('y') -z = x * (x + y) - -relay_func = compiler.to_relay(z) - -print(f"Relay Function:\n{compiler.pp(relay_func)}") - -xform_func = transform(relay_func) - -print(f"Transformed Function:\n{compiler.pp(xform_func)}") - -# Use the evaluator. -norm = defn_to_pyfunc(env, relay_func) -xform = defn_to_pyfunc(env, xform_func) - -x = np.random.uniform(size=(10, 5, 10)).astype('float32') -y = np.random.uniform(size=(10, 5, 10)).astype('float32') - -norm_out = norm(x, y).asnumpy() -xform_out = xform(x, y).asnumpy() - -np.testing.assert_allclose(norm_out, xform_out) - -# Use the TVM runtime. - From cde08544178c6930c8578b8068eaa7b122915062 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Fri, 14 Sep 2018 01:15:22 -0700 Subject: [PATCH 87/88] change fvisitor to mutator to match tvm name convention/type signature --- include/tvm/relay/expr_visitor.h | 61 ++++++++++++++++++-------------- src/relay/pass/resolve.cc | 4 +-- 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 0febad503b12..9d65247630c6 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -69,35 +69,42 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor { virtual void VisitType(const Type& t) {} }; -class ExprFVisitor : public ::tvm::relay::ExprFunctor { +// Note: although IRMutator in TVM return the old expr if the result is structurally unchanged (hash consing), +// we do not does hash consing for ExprMutator, as relay is base on tree rather then graph - even if the old expression is returned, +// it will be treated as a brand new one across many place. +class ExprMutator : public ::tvm::relay::ExprFunctor { public: - Expr VisitExpr_(const LocalVarNode* op) override { - return GetRef(op); + Expr Mutate(const Expr & self) { + return this->VisitExpr(self, self); } - Expr VisitExpr_(const ConstantNode* op) override { - return GetRef(op); + Expr VisitExpr_(const LocalVarNode* op, const Expr & self) override { + return self; } - Expr VisitExpr_(const GlobalVarNode* op) override { - return GetRef(op); + Expr VisitExpr_(const ConstantNode* op, const Expr & self) override { + return self; } - Expr VisitExpr_(const OpNode* op) override { - return GetRef(op); + Expr VisitExpr_(const GlobalVarNode* op, const Expr & self) override { + return self; } - Expr VisitExpr_(const TupleNode* op) override { + Expr VisitExpr_(const OpNode* op, const Expr & self) override { + return self; + } + + Expr VisitExpr_(const TupleNode* op, const Expr & self) override { tvm::Array fields; for (auto field : op->fields) { - fields.push_back(this->VisitExpr(field)); + fields.push_back(this->Mutate(field)); } return TupleNode::make(fields); } - Expr VisitExpr_(const ParamNode* op) override { - Expr var_expr = this->VisitExpr(op->var); + Expr VisitExpr_(const ParamNode* op, const Expr & self) override { + Expr var_expr = this->Mutate(op->var); if (const LocalVarNode* var_node = var_expr.as()) { auto var = GetRef(var_node); auto type = this->VisitType(op->type); @@ -107,7 +114,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { } } - Expr VisitExpr_(const FunctionNode* op) override { + Expr VisitExpr_(const FunctionNode* op, const Expr & self) override { tvm::Array ty_params; for (auto ty : op->type_params) { @@ -122,7 +129,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { tvm::Array params; for (auto param : op->params) { - Expr param_expr = this->VisitExpr(param); + Expr param_expr = this->Mutate(param); if (const ParamNode* param_node = param_expr.as()) { auto param = GetRef(param_node); params.push_back(param); @@ -132,12 +139,12 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { } auto ret_type = this->VisitType(op->ret_type); - auto body = this->VisitExpr(op->body); + auto body = this->Mutate(op->body); return FunctionNode::make(params, ret_type, body, ty_params); } - Expr VisitExpr_(const CallNode* call_node) override { - auto fn = this->VisitExpr(call_node->op); + Expr VisitExpr_(const CallNode* call_node, const Expr & self) override { + auto fn = this->Mutate(call_node->op); tvm::Array ty_args; for (auto ty_arg : call_node->type_args) { @@ -147,7 +154,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { tvm::Array call_args; for (auto arg : call_node->args) { - call_args.push_back(this->VisitExpr(arg)); + call_args.push_back(this->Mutate(arg)); } auto call = CallNode::make(fn, call_args, call_node->attrs, ty_args); @@ -155,23 +162,23 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { return call; } - Expr VisitExpr_(const LetNode* op) override { - Expr var_expr = this->VisitExpr(op->var); + Expr VisitExpr_(const LetNode* op, const Expr & self) override { + Expr var_expr = this->Mutate(op->var); if (const LocalVarNode* var_node = var_expr.as()) { auto var = GetRef(var_node); auto type = this->VisitType(op->value_type); - auto value = this->VisitExpr(op->value); - auto body = this->VisitExpr(op->body); + auto value = this->Mutate(op->value); + auto body = this->Mutate(op->body); return LetNode::make(var, value, body, type); } else { throw dmlc::Error("the default let visitor has error"); } } - Expr VisitExpr_(const IfNode* op) override { - auto guard = this->VisitExpr(op->cond); - auto true_b = this->VisitExpr(op->true_value); - auto false_b = this->VisitExpr(op->false_value); + Expr VisitExpr_(const IfNode* op, const Expr & self) override { + auto guard = this->Mutate(op->cond); + auto true_b = this->Mutate(op->true_value); + auto false_b = this->Mutate(op->false_value); return IfNode::make(guard, true_b, false_b); } diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index bc63d939959e..5549388177b9 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -33,7 +33,7 @@ struct ResolveTypeType : TypeFVisitor { } }; -struct ResolveTypeExpr : ExprFVisitor { +struct ResolveTypeExpr : ExprMutator { const TypeUnifier &unifier; explicit ResolveTypeExpr(const TypeUnifier &unifier) : unifier(unifier) {} @@ -52,7 +52,7 @@ struct ResolveTypeExpr : ExprFVisitor { // We will visit e like normal building a new // term, then resolve e's old type and write // it back into the new node. - auto new_e = ExprFVisitor::VisitExpr(e); + auto new_e = ExprMutator::Mutate(e); CHECK(e->checked_type_.defined()); auto resolved_cty = VisitType(e->checked_type_); new_e->checked_type_ = resolved_cty; From b77082f7824ae6a7287b88bf1f5d844cb43be2a2 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Fri, 14 Sep 2018 01:50:02 -0700 Subject: [PATCH 88/88] do the same for typefvisitor --- src/relay/pass/resolve.cc | 14 +++++------ src/relay/pass/type_infer.cc | 10 ++++---- src/relay/pass/type_subst.cc | 8 +++---- src/relay/pass/type_visitor.h | 45 +++++++++++++++++++---------------- src/relay/pass/unifier.cc | 8 +++---- 5 files changed, 43 insertions(+), 42 deletions(-) diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index 5549388177b9..fd07083b689d 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -13,23 +13,23 @@ namespace tvm { namespace relay { // TODO(@jroesch): We should probably generalize the subst code. -struct ResolveTypeType : TypeFVisitor { +struct ResolveTypeType : TypeMutator { const TypeUnifier &unifier; explicit ResolveTypeType(const TypeUnifier &unifier) : unifier(unifier) {} - Type VisitType(const Type &t) override { + Type Mutate(const Type &t) override { if (!t.defined()) { auto inc_ty = IncompleteTypeNode::make(TypeParamNode::Kind::kType); unifier->insert(inc_ty); return inc_ty; } else { - return TypeFVisitor::VisitType(t); + return TypeMutator::Mutate(t); } } - Type VisitType_(const IncompleteTypeNode *op) override { - return unifier->subst(GetRef(op)); + Type VisitType_(const IncompleteTypeNode *op, const Type & self) override { + return unifier->subst(self); } }; @@ -60,13 +60,13 @@ struct ResolveTypeExpr : ExprMutator { } Type VisitType(const Type &t) { - return ResolveTypeType(unifier).VisitType(t); + return ResolveTypeType(unifier).Mutate(t); } }; Type Resolve(const TypeUnifier &unifier, const Type &ty) { CHECK(ty.defined()); - return ResolveTypeType(unifier).VisitType(ty); + return ResolveTypeType(unifier).Mutate(ty); } Expr Resolve(const TypeUnifier &unifier, const Expr &expr) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 4873b0a55580..1ed4c57f1798 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -58,17 +58,15 @@ struct TypeContext { }; }; -struct TypeNormalizer : TypeFVisitor { +struct TypeNormalizer : TypeMutator { TypeUnifier unifier; explicit TypeNormalizer(const TypeUnifier &unifier) : unifier(unifier) {} - Type VisitType_(const TypeCallNode *ty_call_node) { - auto ty_call = GetRef(ty_call_node); - + Type VisitType_(const TypeCallNode *ty_call, const Expr & self) { Array normalized_args; for (auto arg : ty_call->args) { - normalized_args.push_back(VisitType(arg)); + normalized_args.push_back(Mutate(arg)); } auto all_concrete = true; @@ -164,7 +162,7 @@ TypeInferencer::TypeInferencer(Environment env) : env(env) { Type TypeInferencer::Normalize(const Type &t) { auto nt = this->resolve(t); auto normalizer = TypeNormalizer(this->unifier); - return normalizer.VisitType(nt); + return normalizer.Mutate(nt); } CheckedExpr TypeInferencer::Infer(const Expr &expr) { diff --git a/src/relay/pass/type_subst.cc b/src/relay/pass/type_subst.cc index 91713976bcaa..99718447fa02 100644 --- a/src/relay/pass/type_subst.cc +++ b/src/relay/pass/type_subst.cc @@ -9,13 +9,13 @@ namespace tvm { namespace relay { -struct TypeSubstV : TypeFVisitor { +struct TypeSubstV : TypeMutator { tvm::Map subst_map; explicit TypeSubstV(tvm::Map subst_map) : subst_map(subst_map) {} - Type VisitType_(const TypeParamNode *op) override { + Type VisitType_(const TypeParamNode *op, const Type & self) override { auto id = GetRef(op); if (subst_map.find(id) != subst_map.end()) { return this->subst_map[id]; @@ -27,12 +27,12 @@ struct TypeSubstV : TypeFVisitor { Type TypeSubst(const Type &type, const TypeParam &target, const Type &subst) { TypeSubstV ty_sub({ {target, subst} }); - return ty_sub.VisitType(type); + return ty_sub.Mutate(type); } Type TypeSubst(const Type &type, tvm::Map subst_map) { TypeSubstV ty_sub(subst_map); - return ty_sub.VisitType(type); + return ty_sub.Mutate(type); } } // namespace relay diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index d65d6c567b23..d9f6edae5de4 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -52,17 +52,20 @@ struct TypeVisitor : ::tvm::relay::TypeFunctor { }; // A functional visitor for rebuilding an AST in place. -struct TypeFVisitor : TypeFunctor { - Type VisitType_(const TensorTypeNode* op) override { +struct TypeMutator : TypeFunctor { + virtual Type Mutate(const Type & self) { + return this->VisitType(self, self); + } + Type VisitType_(const TensorTypeNode* op, const Type & self) override { // TODO(@jroesch): maybe we should recursively visit - return TensorTypeNode::make(op->shape, op->dtype); + return self; } - Type VisitType_(const TypeParamNode* op) override { - return GetRef(op); + Type VisitType_(const TypeParamNode* op, const Type & self) override { + return self; } - Type VisitType_(const FuncTypeNode* op) override { + Type VisitType_(const FuncTypeNode* op, const Type & self) override { // TODO(@jroesch): handle poly // auto new_id = this->VisitType(op->var); @@ -72,36 +75,36 @@ struct TypeFVisitor : TypeFunctor { std::vector args; for (auto arg_type : op->arg_types) { - args.push_back(VisitType(arg_type)); + args.push_back(this->Mutate(arg_type)); } - return FuncTypeNode::make(tvm::Array(args), VisitType(op->ret_type), + return FuncTypeNode::make(tvm::Array(args), Mutate(op->ret_type), {}, {}); // fix me } - Type VisitType_(const TupleTypeNode* op) override { - std::vector new_fields; - for (const Type& t : op->fields) { - new_fields.push_back(this->VisitType(t)); - } - return TupleTypeNode::make(new_fields); + Type VisitType_(const TupleTypeNode* op, const Type & self) override { + std::vector new_fields; + for (const Type& t : op->fields) { + new_fields.push_back(this->Mutate(t)); } + return TupleTypeNode::make(new_fields); + } - Type VisitType_(const TypeRelationNode* op) override { - return GetRef(op); + Type VisitType_(const TypeRelationNode* op, const Type & self) override { + return self; } - Type VisitType_(const TypeCallNode* op) override { - auto func = this->VisitType(op->func); + Type VisitType_(const TypeCallNode* op, const Type & self) override { + auto func = this->Mutate(op->func); std::vector new_args; for (const Type& t : op->args) { - new_args.push_back(this->VisitType(t)); + new_args.push_back(this->Mutate(t)); } return TypeCallNode::make(func, new_args); } - Type VisitType_(const IncompleteTypeNode* op) override { - return GetRef(op); + Type VisitType_(const IncompleteTypeNode* op, const Type & self) override { + return self; } }; diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index f5e337eb17f7..1aec76629b35 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -146,26 +146,26 @@ Type TypeUnifierNode::unify(const Type &t1, const Type &t2) { return unified; } -struct IncompleteTypeSubst : TypeFVisitor { +struct IncompleteTypeSubst : TypeMutator { const TypeUnifierNode *unifier; IncompleteTypeSubst(const TypeUnifierNode *unifier) : unifier(unifier) {} // type var: look it up in the type map and recurse - Type VisitType_(const IncompleteTypeNode *op) override { + Type VisitType_(const IncompleteTypeNode *op, const Type & self) override { auto tv = GetRef(op); auto parent = unifier->uf->find(tv); if (parent == tv) { return tv; } - return this->VisitType(parent); + return this->Mutate(parent); } }; Type TypeUnifierNode::subst(const Type &t) { IncompleteTypeSubst tvsubst(this); // normalize first so substitutions in quantifiers will be correct - Type ret = tvsubst.VisitType(t); + Type ret = tvsubst.Mutate(t); // if (!check_kind(ret)) { // std::stringstream ss; // ss << "Invalid Kinds in substituted type!";