Skip to content

Commit

Permalink
[REFACTOR][TYPE] Finish move all types to IR. (#4746)
Browse files Browse the repository at this point in the history
* [REFACTOR][TYPE] Finish move all types to IR.

- Move definition of Ref and TensorType to ir
- Move type_functor.h to public header.
- Rename RefType -> RelayRefType for clarity.

* Add atol
  • Loading branch information
tqchen authored Jan 20, 2020
1 parent ee0af84 commit 2c0c184
Show file tree
Hide file tree
Showing 56 changed files with 486 additions and 408 deletions.
117 changes: 117 additions & 0 deletions include/tvm/ir/tensor_type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ir/tensor_type.h
* \brief Polymorphic tensor types.
*/
#ifndef TVM_IR_TENSOR_TYPE_H_
#define TVM_IR_TENSOR_TYPE_H_

#include <tvm/ir/type.h>
#include <tvm/ir/expr.h>

namespace tvm {
/*!
* \brief Base of all Tensor types
* This container can hold TensorType or GenericTensorType.
* \sa BaseTensorType, TensorTypeNode
*/
class BaseTensorTypeNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.BaseTensorType";
TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
};

/*!
* \brief Managed reference to BaseTensorTypeNode.
* \sa BaseTensorTypeNode.
*/
class BaseTensorType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode);
};

/*!
* \brief This is the most commonly used type in relay.
* TensorType have a fixed dimension, data type.
*
* The elements of shape can be either IntImm(constant integer),
* or any symbolic integer expression.
* The symbolic integer allows generic shape inference in certain cases.
* \sa TensorType
*/
class TensorTypeNode : public BaseTensorTypeNode {
public:
/*!
* \brief The shape of the tensor,
* represented by PrimExpr(tvm::Expr).
*/
Array<PrimExpr> shape;
/*! \brief The content data type */
DataType dtype;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
v->Visit("span", &span);
}

/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/
TVM_DLL PrimExpr Size() const;

static constexpr const char* _type_key = "relay.TensorType";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
};

/*!
* \brief Managed reference to TensorTypeNode.
* \sa TensorTypeNode.
*/
class TensorType : public Type {
public:
/*!
* \brief Constructor.
* \param shape The shape of the tensor.
* \param dtype The runtime dtype of the tensor's elements.
*/
TVM_DLL TensorType(Array<PrimExpr> shape, DataType dtype);

/*!
* \brief Construct an scalar containing elements of dtype.
* \param dtype The runtime dtype of the tensor's elements.
* \return THe constructed type.
*/
TVM_DLL static TensorType Scalar(DataType dtype);

TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
};

// The following fields contains advanced typing
// Only keep the class name and reserved for future usage.
class GenericTensorType;
// stores a DataType.
class GenericDataType;
// stores a DataType.
class GenericShape;

} // namespace tvm
#endif // TVM_IR_TENSOR_TYPE_H_
70 changes: 70 additions & 0 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,5 +352,75 @@ class FuncType : public Type {
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
};

/*!
* \brief Intermediate values that is used to indicate incomplete type
* during type inference.
*
* If we view the type relations as "computational graph of types",
* then IncompleteType represents intermediate values of the graph,
* TypeVar represents the input to the graph.
*
* \sa IncompleteType
*/
class IncompleteTypeNode : public TypeNode {
public:
/*! \brief kind of the type. */
TypeKind kind;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("kind", &kind);
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};

/*!
* \brief Managed reference to IncompleteTypeNode.
* \sa IncompleteTypeNode
*/
class IncompleteType : public Type {
public:
/*!
* \brief Constructor.
* \param kind kind of the type.
*/
TVM_DLL explicit IncompleteType(TypeKind kind);

TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
};


/*!
* \brief Reference Type High-level Relay IR.
*
* \sa RelayRefType.
*/
class RelayRefTypeNode : public TypeNode {
public:
/*! \brief The type of value in the Reference. */
Type value;

RelayRefTypeNode() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
};

/*!
* \brief Managed reference to RelayRefTypeNode.
* \sa RelayRefTypeNode.
*/
class RelayRefType : public Type {
public:
TVM_DLL explicit RelayRefType(Type value);
TVM_DEFINE_OBJECT_REF_METHODS(RelayRefType, Type, RelayRefTypeNode);
};
} // namespace tvm
#endif // TVM_IR_TYPE_H_
58 changes: 33 additions & 25 deletions src/relay/ir/type_functor.h → include/tvm/ir/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
*/

/*!
* \file type_functor.h
* \file tvm/ir/type_functor.h
* \brief A way to defined arbitrary function signature with dispatch on types.
*/
#ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_
#define TVM_RELAY_IR_TYPE_FUNCTOR_H_
#ifndef TVM_IR_TYPE_FUNCTOR_H_
#define TVM_IR_TYPE_FUNCTOR_H_

#include <tvm/node/functor.h>
#include <tvm/relay/expr.h>
Expand All @@ -32,17 +32,16 @@
#include <utility>

namespace tvm {
namespace relay {

template <typename FType>
class TypeFunctor;

// functions to be overriden.
#define TYPE_FUNCTOR_DEFAULT \
#define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); }


#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
#define TVM_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const ObjectRef& n, TSelf* self, Args... args) { \
return self->VisitType_(static_cast<const OP*>(n.get()), \
Expand Down Expand Up @@ -89,10 +88,11 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const RelayRefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitTypeDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
throw; // unreachable, written to stop compiler warning
Expand All @@ -103,40 +103,48 @@ class TypeFunctor<R(const Type& n, Args...)> {
static FType InitVTable() {
FType vtable;
// Set dispatch
RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
TVM_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
TVM_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(RelayRefTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
return vtable;
}
};

#undef TVM_TYPE_FUNCTOR_DISPATCH

/*!
* \brief A type visitor that recursively visit types.
*/
class TypeVisitor : public TypeFunctor<void(const Type& n)> {
class TVM_DLL TypeVisitor :
public TypeFunctor<void(const Type& n)> {
public:
void VisitType_(const TypeVarNode* op) override;
void VisitType_(const IncompleteTypeNode* op) override;
void VisitType_(const TensorTypeNode* op) override;
void VisitType_(const FuncTypeNode* op) override;
void VisitType_(const TupleTypeNode* op) override;
void VisitType_(const TypeRelationNode* op) override;
void VisitType_(const RefTypeNode* op) override;
void VisitType_(const RelayRefTypeNode* op) override;
void VisitType_(const GlobalTypeVarNode* op) override;
void VisitType_(const TypeCallNode* op) override;
void VisitType_(const TypeDataNode* op) override;
void VisitType_(const PrimTypeNode* op) override;
};

// Mutator that transform a type to another one.
class TypeMutator : public TypeFunctor<Type(const Type& n)> {
/*!
* \brief TypeMutator that mutates expressions.
*/
class TVM_DLL TypeMutator :
public TypeFunctor<Type(const Type& n)> {
public:
Type VisitType(const Type& t) override;
Type VisitType_(const TypeVarNode* op) override;
Expand All @@ -145,10 +153,11 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
Type VisitType_(const FuncTypeNode* op) override;
Type VisitType_(const TupleTypeNode* op) override;
Type VisitType_(const TypeRelationNode* type_rel) override;
Type VisitType_(const RefTypeNode* op) override;
Type VisitType_(const RelayRefTypeNode* op) override;
Type VisitType_(const GlobalTypeVarNode* op) override;
Type VisitType_(const TypeCallNode* op) override;
Type VisitType_(const TypeDataNode* op) override;
Type VisitType_(const PrimTypeNode* op) override;

private:
Array<Type> MutateArray(Array<Type> arr);
Expand All @@ -161,6 +170,5 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
*/
Type Bind(const Type& type, const Map<TypeVar, Type>& args_map);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_IR_TYPE_FUNCTOR_H_
#endif // TVM_IR_TYPE_FUNCTOR_H_
Loading

0 comments on commit 2c0c184

Please sign in to comment.