Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][TYPE] Finish move all types to IR. #4746

Merged
merged 2 commits into from
Jan 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions include/tvm/ir/tensor_type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ir/tensor_type.h
* \brief Polymorphic tensor types.
*/
#ifndef TVM_IR_TENSOR_TYPE_H_
#define TVM_IR_TENSOR_TYPE_H_

#include <tvm/ir/type.h>
#include <tvm/ir/expr.h>

namespace tvm {
/*!
* \brief Base of all Tensor types
* This container can hold TensorType or GenericTensorType.
* \sa BaseTensorType, TensorTypeNode
*/
class BaseTensorTypeNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.BaseTensorType";
TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
};

/*!
* \brief Managed reference to BaseTensorTypeNode.
* \sa BaseTensorTypeNode.
*/
class BaseTensorType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode);
};

/*!
* \brief This is the most commonly used type in relay.
* TensorType have a fixed dimension, data type.
*
* The elements of shape can be either IntImm(constant integer),
* or any symbolic integer expression.
* The symbolic integer allows generic shape inference in certain cases.
* \sa TensorType
*/
class TensorTypeNode : public BaseTensorTypeNode {
public:
/*!
* \brief The shape of the tensor,
* represented by PrimExpr(tvm::Expr).
*/
Array<PrimExpr> shape;
/*! \brief The content data type */
DataType dtype;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
v->Visit("span", &span);
}

/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/
TVM_DLL PrimExpr Size() const;

static constexpr const char* _type_key = "relay.TensorType";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
};

/*!
* \brief Managed reference to TensorTypeNode.
* \sa TensorTypeNode.
*/
class TensorType : public Type {
public:
/*!
* \brief Constructor.
* \param shape The shape of the tensor.
* \param dtype The runtime dtype of the tensor's elements.
*/
TVM_DLL TensorType(Array<PrimExpr> shape, DataType dtype);

/*!
* \brief Construct an scalar containing elements of dtype.
* \param dtype The runtime dtype of the tensor's elements.
* \return THe constructed type.
*/
TVM_DLL static TensorType Scalar(DataType dtype);

TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
};

// The following fields contains advanced typing
// Only keep the class name and reserved for future usage.
class GenericTensorType;
// stores a DataType.
class GenericDataType;
// stores a DataType.
class GenericShape;

} // namespace tvm
#endif // TVM_IR_TENSOR_TYPE_H_
70 changes: 70 additions & 0 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,5 +352,75 @@ class FuncType : public Type {
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
};

/*!
* \brief Intermediate values that is used to indicate incomplete type
* during type inference.
*
* If we view the type relations as "computational graph of types",
* then IncompleteType represents intermediate values of the graph,
* TypeVar represents the input to the graph.
*
* \sa IncompleteType
*/
class IncompleteTypeNode : public TypeNode {
public:
/*! \brief kind of the type. */
TypeKind kind;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("kind", &kind);
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};

/*!
* \brief Managed reference to IncompleteTypeNode.
* \sa IncompleteTypeNode
*/
class IncompleteType : public Type {
public:
/*!
* \brief Constructor.
* \param kind kind of the type.
*/
TVM_DLL explicit IncompleteType(TypeKind kind);

TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
};


/*!
* \brief Reference Type High-level Relay IR.
*
* \sa RelayRefType.
*/
class RelayRefTypeNode : public TypeNode {
public:
/*! \brief The type of value in the Reference. */
Type value;

RelayRefTypeNode() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
};

/*!
* \brief Managed reference to RelayRefTypeNode.
* \sa RelayRefTypeNode.
*/
class RelayRefType : public Type {
public:
TVM_DLL explicit RelayRefType(Type value);
TVM_DEFINE_OBJECT_REF_METHODS(RelayRefType, Type, RelayRefTypeNode);
};
} // namespace tvm
#endif // TVM_IR_TYPE_H_
58 changes: 33 additions & 25 deletions src/relay/ir/type_functor.h → include/tvm/ir/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
*/

/*!
* \file type_functor.h
* \file tvm/ir/type_functor.h
* \brief A way to defined arbitrary function signature with dispatch on types.
*/
#ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_
#define TVM_RELAY_IR_TYPE_FUNCTOR_H_
#ifndef TVM_IR_TYPE_FUNCTOR_H_
#define TVM_IR_TYPE_FUNCTOR_H_

#include <tvm/node/functor.h>
#include <tvm/relay/expr.h>
Expand All @@ -32,17 +32,16 @@
#include <utility>

namespace tvm {
namespace relay {

template <typename FType>
class TypeFunctor;

// functions to be overriden.
#define TYPE_FUNCTOR_DEFAULT \
#define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); }


#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
#define TVM_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const ObjectRef& n, TSelf* self, Args... args) { \
return self->VisitType_(static_cast<const OP*>(n.get()), \
Expand Down Expand Up @@ -89,10 +88,11 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const RelayRefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitTypeDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
throw; // unreachable, written to stop compiler warning
Expand All @@ -103,40 +103,48 @@ class TypeFunctor<R(const Type& n, Args...)> {
static FType InitVTable() {
FType vtable;
// Set dispatch
RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
TVM_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
TVM_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(RelayRefTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
return vtable;
}
};

#undef TVM_TYPE_FUNCTOR_DISPATCH

/*!
* \brief A type visitor that recursively visit types.
*/
class TypeVisitor : public TypeFunctor<void(const Type& n)> {
class TVM_DLL TypeVisitor :
public TypeFunctor<void(const Type& n)> {
public:
void VisitType_(const TypeVarNode* op) override;
void VisitType_(const IncompleteTypeNode* op) override;
void VisitType_(const TensorTypeNode* op) override;
void VisitType_(const FuncTypeNode* op) override;
void VisitType_(const TupleTypeNode* op) override;
void VisitType_(const TypeRelationNode* op) override;
void VisitType_(const RefTypeNode* op) override;
void VisitType_(const RelayRefTypeNode* op) override;
void VisitType_(const GlobalTypeVarNode* op) override;
void VisitType_(const TypeCallNode* op) override;
void VisitType_(const TypeDataNode* op) override;
void VisitType_(const PrimTypeNode* op) override;
};

// Mutator that transform a type to another one.
class TypeMutator : public TypeFunctor<Type(const Type& n)> {
/*!
* \brief TypeMutator that mutates expressions.
*/
class TVM_DLL TypeMutator :
public TypeFunctor<Type(const Type& n)> {
public:
Type VisitType(const Type& t) override;
Type VisitType_(const TypeVarNode* op) override;
Expand All @@ -145,10 +153,11 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
Type VisitType_(const FuncTypeNode* op) override;
Type VisitType_(const TupleTypeNode* op) override;
Type VisitType_(const TypeRelationNode* type_rel) override;
Type VisitType_(const RefTypeNode* op) override;
Type VisitType_(const RelayRefTypeNode* op) override;
Type VisitType_(const GlobalTypeVarNode* op) override;
Type VisitType_(const TypeCallNode* op) override;
Type VisitType_(const TypeDataNode* op) override;
Type VisitType_(const PrimTypeNode* op) override;

private:
Array<Type> MutateArray(Array<Type> arr);
Expand All @@ -161,6 +170,5 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
*/
Type Bind(const Type& type, const Map<TypeVar, Type>& args_map);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_IR_TYPE_FUNCTOR_H_
#endif // TVM_IR_TYPE_FUNCTOR_H_
Loading