Skip to content

Commit

Permalink
[REFACTOR][IR] Initialize Unified IR Expr Data Structure (apache#4673)
Browse files Browse the repository at this point in the history
This PR moves a few base types from relay and low-level Expr into the ir sub-folder.
These classes will serve as a common type system across the stack.

Rationale:

- PrimExpr for low-level expressions
- RelayExpr for advanced features, including Function definition.
- Introduce BaseFunc to host all functions, including future PrimFunc(low-level expr functions, subject to discussion).

This is a minimum change we can do to unify the classes into a common hierarchy.
The main data structure that are variant specific will still be kept in the sub-namespaces.
We only include classes that is needed to allow a common Module class.
- BaseFunc
- GlobalVar
- Type definition part of ADT

We will only need the BaseFunc and their checked_type to decide the calling convention
across the function variants.
  • Loading branch information
tqchen authored and zhiics committed Mar 2, 2020
1 parent 4ae817a commit 6e5cc4c
Show file tree
Hide file tree
Showing 21 changed files with 587 additions and 326 deletions.
53 changes: 1 addition & 52 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_

#include <tvm/ir/expr.h>
#include <string>
#include <algorithm>
#include <unordered_map>
Expand All @@ -37,58 +38,6 @@

namespace tvm {

/*!
* \brief Base node of all primitive expressions.
*
* A primitive expression deals with low-level
* POD data types and handles without
* doing life-cycle management for objects.
*
* PrimExpr is used in the low-level code
* optimizations and integer analysis.
*
* \sa PrimExpr
*/
class PrimExprNode : public Object {
public:
/*! \brief The data type of the expression. */
DataType dtype;

static constexpr const char* _type_key = "PrimExpr";
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, Object);
};

/*!
* \brief Container of all primitive expressions.
* \sa PrimExprNode
*/
class PrimExpr : public ObjectRef {
public:
PrimExpr() {}
explicit PrimExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(int32_t value); // NOLINT(*)
/*!
* \brief construct from float.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)
/*!
* \brief construct from string.
* \param str The value to be constructed.
*/
TVM_DLL PrimExpr(std::string str); // NOLINT(*)

/*! \return the data type of this expression. */
DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype;
}

using ContainerType = PrimExprNode;
};

/*! \brief Base node of all statements. */
class StmtNode : public Object {
Expand Down
142 changes: 142 additions & 0 deletions include/tvm/ir/adt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/adt.h
* \brief Algebraic data type definitions.
*
* We adopt relay's ADT definition as a unified class
* for decripting structured data.
*/
#ifndef TVM_IR_ADT_H_
#define TVM_IR_ADT_H_

#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/type.h>
#include <string>

namespace tvm {

/*!
* \brief ADT constructor.
* Constructors compare by pointer equality.
* \sa Constructor
*/
class ConstructorNode : public RelayExprNode {
public:
/*! \brief The name (only a hint) */
std::string name_hint;
/*! \brief Input to the constructor. */
Array<Type> inputs;
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
mutable int32_t tag = -1;

ConstructorNode() {}

void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("inputs", &inputs);
v->Visit("belong_to", &belong_to);
v->Visit("tag", &tag);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
};

/*!
* \brief Managed reference to ConstructorNode
* \sa ConstructorNode
*/
class Constructor : public RelayExpr {
public:
/*!
* \brief Constructor
* \param name_hint the name of the constructor.
* \param inputs The input types.
* \param belong_to The data type var the constructor will construct.
*/
TVM_DLL Constructor(std::string name_hint,
Array<Type> inputs,
GlobalTypeVar belong_to);

TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode);
};

/*! \brief TypeData container node */
class TypeDataNode : public TypeNode {
public:
/*!
* \brief The header is simply the name of the ADT.
* We adopt nominal typing for ADT definitions;
* that is, differently-named ADT definitions with same constructors
* have different types.
*/
GlobalTypeVar header;
/*! \brief The type variables (to allow for polymorphism). */
Array<TypeVar> type_vars;
/*! \brief The constructors. */
Array<Constructor> constructors;

void VisitAttrs(AttrVisitor* v) {
v->Visit("header", &header);
v->Visit("type_vars", &type_vars);
v->Visit("constructors", &constructors);
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};

/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*
* In particular, it stores the handle (global type var) for an ADT
* and the constructors used to build it and is kept in the module. Note
* that type parameters are also indicated in the type data: this means that
* for any instance of an ADT, the type parameters must be indicated. That is,
* an ADT definition is treated as a type-level function, so an ADT handle
* must be wrapped in a TypeCall node that instantiates the type-level arguments.
* The kind checker enforces this.
*/
class TypeData : public Type {
public:
/*!
* \brief Constructor
* \param header the name of ADT.
* \param type_vars type variables.
* \param constructors constructors field.
*/
TVM_DLL TypeData(GlobalTypeVar header,
Array<TypeVar> type_vars,
Array<Constructor> constructors);

TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode);
};

} // namespace tvm
#endif // TVM_IR_ADT_H_
Loading

0 comments on commit 6e5cc4c

Please sign in to comment.