Skip to content

Commit

Permalink
Relax IRBuilder (#4)
Browse files Browse the repository at this point in the history
* Add initial IRBuilder.

* Add function output to irbuilder; update based on new AST.

* Add call method; clean up bindings

* Add test.

* Add multifuction test

* Move implementation to C++; infer shape and type

* update op python hook

* More tests and bug fix

* Add comments.

* Update shape/type inference.

* Restructure code; add python type hint.

* Cleanup code.

* Rebase; address comments.

* Add call intrinsic.

* nits.

* Remove call op.

* Migrate scope to C++ using tvm::With.

* Address naming.

* Add GetBlocks API.

* Unify EmitOutput APIs; add more comments.

* Remove shape and type deduction code.

* Also remove the shape/type attr interface.

* Address comments.

* Differentiate global and local function.

* Reset counter after building func/block.

* Rebase.

* Remove shape infer builtin.

* Return from void function as empty tuple.

Co-authored-by: Michalis Papadimitriou <[email protected]>
  • Loading branch information
YuchenJin and mikepapadim committed Nov 17, 2022
1 parent 2340269 commit 58a5a89
Show file tree
Hide file tree
Showing 15 changed files with 789 additions and 11 deletions.
2 changes: 1 addition & 1 deletion include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ class TupleType : public Type {
inline Type VoidType() { return TupleType::Empty(); }

/*!
* \brief Check whether the tyep represents void.
* \brief Check whether the type represents void.
* \return The check result.
*/
inline bool IsVoidType(const Type& type) {
Expand Down
197 changes: 197 additions & 0 deletions include/tvm/relax/ir_builder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/ir_builder.h
* \brief The utility for constructing Relax AST.
*/
#ifndef TVM_RELAX_IR_BUILDER_H_
#define TVM_RELAX_IR_BUILDER_H_

#include <tvm/ir/expr.h>
#include <tvm/relax/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
#include <tvm/support/with.h>

namespace tvm {
namespace relax {

using relay::Call;

class IRBuilder;

/*!
* \brief The state of Relax function node being built.
*/
struct RelaxFunction {
/*! \brief The function name. */
Optional<GlobalVar> func_name = NullOpt;
/*! \brief The function parameters. */
Array<Var> params;
/*! \brief The bindings in the function. */
std::vector<Binding> bindings;
/*! \brief The binding blocks in the function. */
std::vector<BindingBlock> binding_blocks;
/*! \brief The return of the function. */
Expr ret = Tuple();
/*! \brief The FunctionNode being built. */
Function func;
};

/*!
* \brief A builder that provides APIs to build Relax AST.
*/
class IRBuilderNode : public Object {
public:
/*!
* \brief Fill the function name and parameters.
*/
void FillFuncNameParam(const Array<Var>& params, const std::string& func_name);
/*!
* \brief Build a function node.
*/
void BuildFunction();
/*!
* \brief Build a binding block.
*/
void BuildBlock();
/*!
* \brief Emit a call node.
* \param call The CallNode to be emitted.
* \return The variable being created and binded to \p call.
*/
Var Emit(const Call& call);
/*!
* \brief Generate an output for the current dataflow block or function.
* \param output The output variable of the block/function.
* \return The variable being binded to \p ouput.
*/
Var EmitOutput(const Expr& output);
/*!
* \brief Get the function being built.
*/
Function Get();
/*!
* \brief Get binding blocks being built.
*/
std::vector<BindingBlock> GetBlocks();
/*!
* \brief Create a IRBuilder.
* \return The created IRBuilder.
*/
TVM_DLL static IRBuilder Create();

void VisitAttrs(AttrVisitor* v) {}

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.IRBuilder";
TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, Object);

private:
/*! \brief The state of the function currently being built. */
RelaxFunction func;
/*! \brief A flag tracking if currently inside a dataflow block or not. */
bool is_dataflow = false;
/*! \brief A global variable counter for naming global variables. */
int global_var_counter = 0;
/*! \brief A dataflow variable counter for naming dataflow variables. */
int dataflow_var_counter = 0;
};

class IRBuilder : public ObjectRef {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode);
};

/*! \brief Auxiliary scope for building Relax function node,
* similar to python's with syntax.
*
* \code
* {
* With<FunctionScope> scope(ir_builder);
* // build function node.
* }
*/
class FunctionScopeNode : public Object {
public:
IRBuilder ir_builder;
void VisitAttrs(AttrVisitor* v) { v->Visit("ir_builder", &ir_builder); }

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.FunctionScope";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionScopeNode, Object);
};

class FunctionScope : public ObjectRef {
public:
TVM_DLL FunctionScope(IRBuilder ib);
TVM_DEFINE_OBJECT_REF_METHODS(FunctionScope, ObjectRef, FunctionScopeNode);
class Internal;

private:
// Classes to get the Python `with` like syntax.
friend class Internal;
friend class With<FunctionScope>;
// The entry of a function scope.
TVM_DLL void EnterWithScope();
// The exit of a function scope.
TVM_DLL void ExitWithScope();
};

/*! \brief Auxiliary scope for building Relax dataflow block,
* similar to python's with syntax.
*
* \code
* {
* With<DataflowScope> scope(ir_builder);
* // build dataflow block.
* }
*/
class DataflowScopeNode : public Object {
public:
IRBuilder ir_builder;
void VisitAttrs(AttrVisitor* v) { v->Visit("ir_builder", &ir_builder); }

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.DataflowScope";
TVM_DECLARE_FINAL_OBJECT_INFO(DataflowScopeNode, Object);
};

class DataflowScope : public ObjectRef {
public:
TVM_DLL DataflowScope(IRBuilder ib);
TVM_DEFINE_OBJECT_REF_METHODS(DataflowScope, ObjectRef, DataflowScopeNode);
class Internal;

private:
// Classes to get the Python `with` like syntax.
friend class Internal;
friend class With<DataflowScope>;
// The entry of a dataflow scope.
TVM_DLL void EnterWithScope();
// The exit of a dataflow scope.
TVM_DLL void ExitWithScope();
};

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_IR_BUILDER_H_
11 changes: 6 additions & 5 deletions include/tvm/relax/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ namespace relax {

class ShapeTypeNode : public TypeNode {
public:

void VisitAttrs(tvm::AttrVisitor* v) {
}
void VisitAttrs(tvm::AttrVisitor* v) {}

bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const {
return true;
Expand All @@ -64,10 +62,9 @@ class ShapeType : public Type {
const ShapeTypeNode* get() const {
return operator->();
}
using ContainerType = ShapeTypeNode;
using ContainerType = ShapeTypeNode;
};


class DynTensorTypeNode : public BaseTensorTypeNode {
public:
/*!
Expand All @@ -92,6 +89,10 @@ class DynTensorTypeNode : public BaseTensorTypeNode {
hash_reduce(dtype);
}

inline bool IsUnknownRank() const { return rank == -1; }

inline bool IsUnknownDtype() const { return dtype.is_void(); }

static constexpr const char* _type_key = "relax.DynTensorType";
TVM_DECLARE_FINAL_OBJECT_INFO(DynTensorTypeNode, BaseTensorTypeNode);
};
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from . import ty
from . import vm
from . import op
from . import ir_builder
from . import op


# Expr
Expand Down Expand Up @@ -56,3 +58,6 @@

# Operator
from .op.base import call_dps

# IRBuilder
IRBuilder = ir_builder.IRBuilder
Loading

0 comments on commit 58a5a89

Please sign in to comment.