Skip to content

Commit

Permalink
[REFACTOR][IR] Initialize Unified IR Expr Data Structure
Browse files Browse the repository at this point in the history
This PR moves a few base types from relay and low-level Expr into the ir sub-folder.
These classes will serve as a common type system across the stack.

Rationale:

- PrimExpr for low-level expressions
- RelayExpr for advanced features, including Function definition.
- Introduce BaseFunc to host all functions, including future PrimFunc(low-level expr functions, subject to discussion).

This is a minimum change we can do to unify the classes into a common hierarchy.
The main data structure that are variant specific will still be kept in the sub-namespaces.
We only include classes that is needed to allow a common Module class.
- BaseFunc
- GlobalVar
- Possibly part of ADT in a subsequent PR.

We will only need the BaseFunc and their checked_type to decide the calling convention
across the function variants.
  • Loading branch information
tqchen committed Jan 10, 2020
1 parent d6a23cf commit 2f967ba
Show file tree
Hide file tree
Showing 16 changed files with 355 additions and 193 deletions.
53 changes: 1 addition & 52 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_

#include <tvm/ir/expr.h>
#include <string>
#include <algorithm>
#include <unordered_map>
Expand All @@ -37,58 +38,6 @@

namespace tvm {

/*!
* \brief Base node of all primitive expressions.
*
* A primitive expression deals with low-level
* POD data types and handles without
* doing life-cycle management for objects.
*
* PrimExpr is used in the low-level code
* optimizations and integer analysis.
*
* \sa PrimExpr
*/
class PrimExprNode : public Object {
public:
/*! \brief The data type of the expression. */
DataType dtype;

static constexpr const char* _type_key = "PrimExpr";
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, Object);
};

/*!
* \brief Container of all primitive expressions.
* \sa PrimExprNode
*/
class PrimExpr : public ObjectRef {
public:
PrimExpr() {}
explicit PrimExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(int32_t value); // NOLINT(*)
/*!
* \brief construct from float.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)
/*!
* \brief construct from string.
* \param str The value to be constructed.
*/
TVM_DLL PrimExpr(std::string str); // NOLINT(*)

/*! \return the data type of this expression. */
DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype;
}

using ContainerType = PrimExprNode;
};

/*! \brief Base node of all statements. */
class StmtNode : public Object {
Expand Down
270 changes: 270 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ir/expr.h
* \brief Base expr nodes in TVM.
*/
#ifndef TVM_IR_EXPR_H_
#define TVM_IR_EXPR_H_

#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
#include <tvm/node/container.h>
#include <tvm/ir/span.h>
#include <tvm/ir/type.h>
#include <string>

namespace tvm {

/*!
* \brief Base type of all the expressions.
* \sa Expr
*/
class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};

/*!
* \brief Managed reference to BaseExprNode.
* \sa BaseExprNode
*/
class BaseExpr : public ObjectRef {
public:
/*! \brief Cosntructor */
BaseExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit BaseExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! \brief The container type. */
using ContainerType = BaseExprNode;
};

/*!
* \brief Base node of all primitive expressions.
*
* A primitive expression deals with low-level
* POD data types and handles without
* doing life-cycle management for objects.
*
* PrimExpr is used in the low-level code
* optimizations and integer analysis.
*
* \sa PrimExpr
*/
class PrimExprNode : public BaseExprNode {
public:
/*!
* \brief The runtime data type of the primitive expression.
*
* runtime::DataType(dtype) provides coarse grained type information
* during compile time and runtime. It is eagerly built in
* PrimExpr expression construction and can be used for
* quick type checking.
*
* dtype is sufficient to decide the Type of the PrimExpr
* when it corresponds to POD value types such as i32.
*
* When dtype is DataType::Handle(), the expression could corresponds to
* a more fine-grained Type, and we can get the type by running lazy type inference.
*/
DataType dtype;

static constexpr const char* _type_key = "PrimExpr";
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
};

/*!
* \brief Reference to PrimExprNode.
* \sa PrimExprNode
*/
class PrimExpr : public BaseExpr {
public:
/*! \brief Cosntructor */
PrimExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit PrimExpr(ObjectPtr<Object> ptr) : BaseExpr(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(int32_t value); // NOLINT(*)
/*!
* \brief construct from float.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)
/*!
* \brief construct from string.
* \param str The value to be constructed.
*/
TVM_DLL PrimExpr(std::string str); // NOLINT(*)

/*! \return the data type of this expression. */
DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype;
}
/*! \brief The container type. */
using ContainerType = PrimExprNode;
};

/*!
* \brief Base node of all non-primitive expressions.
*
* RelayExpr supports tensor types, functions and ADT as
* first class citizens. The life-cycle of the corresponding
* objects are implicitly managed by the language.
*
* \sa RelayExpr
*/
class RelayExprNode : public BaseExprNode {
public:
/*!
* \brief Span that points to the original source code.
* Reserved debug information.
*/
mutable Span span;
/*!
* \brief Stores the result of type inference(type checking).
*
* \note This can be undefined before type inference.
* This value is discarded during serialization.
*/
mutable Type checked_type_ = Type(nullptr);
/*!
* \return The checked_type
*/
const Type& checked_type() const;
/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
*
* \note This function will thrown an error if the node type
* of this Expr is not TTypeNode.
*
* \return The corresponding TTypeNode pointer.
* \tparam The specific TypeNode we look for.
*/
template<typename TTypeNode>
inline const TTypeNode* type_as() const;

static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode);
};

/*!
* \brief Managed reference to RelayExprNode.
* \sa RelayExprNode
*/
class RelayExpr : public BaseExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RelayExpr, BaseExpr, RelayExprNode);
};

class GlobalVar;
/*!
* \brief Global variable that leaves in the top-level module.
*
* A GlobalVar only refers to function definitions.
* This is used to enable recursive calls between function.
*
* \sa GlobalVarNode
*/
class GlobalVarNode : public RelayExprNode {
public:
/*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

static constexpr const char* _type_key = "relay.GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};

/*!
* \brief Managed reference to GlobalVarNode.
* \sa GlobalVarNode
*/
class GlobalVar : public RelayExpr {
public:
TVM_DLL explicit GlobalVar(std::string name_hint);

TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
};

/*!
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
* All of the functions shares the same type system(via checked_type)
* to support cross variant calls.
*
* \sa BaseFunc
*/
class BaseFuncNode : public RelayExprNode {
public:
static constexpr const char* _type_key = "BaseFunc";
TVM_DECLARE_FINAL_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};

/*!
* \brief Managed reference to BaseFuncNode.
* \sa BaseFuncNode
*/
class BaseFunc : public RelayExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};

// implementataions
inline const Type& RelayExprNode::checked_type() const {
CHECK(checked_type_.defined())
<< "internal error: the type checker has "
<< "not populated the checked_type "
<< "field for "
<< GetRef<RelayExpr>(this);
return this->checked_type_;
}

template<typename TTypeNode>
inline const TTypeNode* RelayExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed. Try to call infer_type pass.";
const TTypeNode* node = checked_type_.as<TTypeNode>();
CHECK(node != nullptr)
<< "Expected type to be " << TTypeNode::_type_key
<< ", but get " << checked_type_->GetTypeKey();
return node;
}

} // namespace tvm
#endif // TVM_IR_EXPR_H_
2 changes: 1 addition & 1 deletion include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
*
* ## Relation between Type and runtime::DataType
*
* Besides Type, we also store a dtype field in some of the low-level IR's Expr.
* Besides Type, we also store a dtype field in the low-level PrimExpr.
* runtime::DataType(dtype) provides coarse grained type information
* during compile time and runtime. It is eagerly built in
* low-level expression construction and can be used for
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class ConstructorNode : public ExprNode {

class Constructor : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Constructor, Expr, ConstructorNode);
TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode);
};

/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
Expand Down Expand Up @@ -306,7 +306,7 @@ class MatchNode : public ExprNode {

class Match : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Match, Expr, MatchNode);
TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode);
};

} // namespace relay
Expand Down
Loading

0 comments on commit 2f967ba

Please sign in to comment.