Skip to content

Commit

Permalink
ExprMutator refactor & Normalizer (apache#32)
Browse files Browse the repository at this point in the history
* fixes

* revert checked_type visitor and fix relax usage

* ExprNormalizer

* fix that annoying bug and get tests passing

* Memoization fix for the ExprMutator; separate VisitVarDef from use.

* rebase.

* rebase.

* address part of comments.

* address more comments

* address more comments and add doc

* address more comments

* fix potential mutation bug

* always assign normalized shape if can

* address comments

Co-authored-by: Altan Haan <[email protected]>
  • Loading branch information
2 people authored and junrushao committed Oct 14, 2022
1 parent 99401ac commit 35d1ae3
Show file tree
Hide file tree
Showing 14 changed files with 619 additions and 209 deletions.
61 changes: 46 additions & 15 deletions include/tvm/relax/block_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,39 +44,43 @@ class BlockBuilder;
*/
class BlockBuilderNode : public Object {
public:
BlockBuilderNode(std::shared_ptr<NameTable> name_table) : name_table_(name_table) {}
BlockBuilderNode();

~BlockBuilderNode();

BlockBuilderNode() { name_table_ = std::make_shared<NameTable>(); }

/*! \brief Begin to build a DataflowBlock. */
void BeginDataflowBlock();

/*! \brief Begin to build a BindingBlock. */
void BeginBindingBlock();

/*!
* \brief End building a BindingBlock.
* \return The BindingBlock being built.
*/
BindingBlock EndBlock();

/*!
* \brief Check if the block being built is DataflowBlock or not.
* \return A boolean that indicates if the block being built is DataflowBlock or not.
*/
inline bool CurrentBlockIsDataFlow() { return CurrentFrame()->is_dataflow; }

/*!
* \brief Emits an Expr, and returns the variable it is bound to.
* \param expr The Expr to be emitted.
* \param name_hint Name hint for the bound variable.
* \return The new variable that \p expr is bound to.
*/
virtual Var Emit(const Expr& expr, std::string name_hint = "");

/*!
* \brief Emits a variable binding, and returns the bound Var.
* \param binding The variable binding.
* \return The bound variable.
*/
virtual Var Emit(const VarBinding& binding);

/*!
* \brief Emit a MatchShape.
* \param value The value of the MatchShape to be emitted.
Expand All @@ -85,49 +89,57 @@ class BlockBuilderNode : public Object {
* \return The variable bound to the MatchShape.
*/
Var EmitMatchShape(const Expr& value, const Array<PrimExpr>& pattern, std::string name_hint = "");

/*!
* \brief Emit a MatchShape binding.
* \param binding The MatchShape binding to be emitted.
* \return The variable bound to the MatchShape.
*/
Var EmitMatchShape(const MatchShape& binding);

/*!
* \brief Generate an output for the current dataflow block.
* \param output The output variable of the block.
* \param name_hint Name hint for the bound variable.
* \return The variable bound to \p output.
*/
Var EmitOutput(const Expr& output, std::string name_hint = "");

/*!
* \brief Generate an output for the current dataflow block.
* \param binding The output binding to output.
* \return The variable bound to \p output.
*/
Var EmitOutput(const VarBinding& binding);

/*!
* \brief Lookup a var in the binding table \p var_map_.
* \brief Lookup a var in the binding table \p binding_table_.
* \param var The input var.
* \return The Expr bound to the input \p var.
*/
Expr LookupVar(const Var& var);
Expr LookupBinding(const Var& var);

/*!
* \brief Check if two shape expressions can be proven equal at compile time.
* \param lhs The input lhs shape.
* \param rhs The input rhs shape.
* \return Whether we can prove lhs shape is the same as the rhs shape.
*/
bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs);

/*!
* \brief Normalize an Expr to complete its shape and type.
* \param expr The input expr.
* \return The expr with normalized shape and type.
* \brief Convert an expression to A-normal form, and try to eagerly infer types and shapes.
* \param expr The input expression.
* \return The normalized expression.
*/
Expr Normalize(const Expr& expr);

/*!
* \brief Create a BlockBuilder.
* \return The created BlockBuilder.
* \brief Get the name table for generating unique names.
*
* \return The name table.
*/
TVM_DLL static BlockBuilder Create();
NameTable* name_table();

void VisitAttrs(AttrVisitor* v) {}

Expand All @@ -150,26 +162,45 @@ class BlockBuilderNode : public Object {
Array<Binding> bindings;
bool is_dataflow;
};

/*!
* \brief Utility class for performing IR normalization (conversion to ANF, eager forward shape
* and type inference).
*/
class ExprNormalizer;

friend class BlockBuilder;

/*!
* \brief Get the current block frame.
* \return The current block frame.
*/
BlockFrame* CurrentFrame();

/*! \brief A stack to store block frames. */
std::stack<BlockFrame> block_stack_;

/*! \brief A diagnostic context for reporting errors. */
DiagnosticContext diag_ctx_ = DiagnosticContext::Default(IRModule({}, {}));

/*! \brief A binding table that maps var to value. */
// TODO(@yuchen, @altanh): make var_map_ scoped, and decide if it should be in the builder
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_map_;
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> binding_table_;

/*! \brief A name table to get unique names for IR construction. */
std::shared_ptr<NameTable> name_table_;
std::unique_ptr<NameTable> name_table_;

/*! \brief The internal normalizer used for ANF conversion. */
std::unique_ptr<ExprNormalizer> normalizer_;
};

class BlockBuilder : public ObjectRef {
public:
TVM_DLL explicit BlockBuilder(std::shared_ptr<NameTable> name_table);
/*!
* \brief Create a BlockBuilder.
* \return The created BlockBuilder.
*/
TVM_DLL static BlockBuilder Create();

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode);
};

Expand Down
10 changes: 10 additions & 0 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ShapeExpr : public Expr {
public:
TVM_DLL explicit ShapeExpr(Array<PrimExpr> values, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, Expr, ShapeExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode);
};

/*! \brief The variable class for all Relax bindings. */
Expand Down Expand Up @@ -131,6 +132,7 @@ class Var : public Expr {
TVM_DLL explicit Var(Id vid, runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
};

/*! \brief A sub-type of the variable node used to mark dataflow variables from
Expand Down Expand Up @@ -175,6 +177,7 @@ class DataflowVar : public Var {
runtime::Optional<Type> type_annotation, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode);
};

/*! \brief The base class of a variable binding in Relax. */
Expand Down Expand Up @@ -235,6 +238,7 @@ class MatchShape : public Binding {
TVM_DLL explicit MatchShape(Expr value, Array<PrimExpr> pattern,
Var var, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchShapeNode);
};

class VarBinding;
Expand Down Expand Up @@ -266,6 +270,7 @@ class VarBinding : public Binding {
public:
TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode);
};

class BindingBlock;
Expand Down Expand Up @@ -296,6 +301,7 @@ class BindingBlock : public ObjectRef {
public:
TVM_DLL explicit BindingBlock(Array<Binding> bindings, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BindingBlockNode);
};

class DataflowBlock;
Expand All @@ -315,6 +321,7 @@ class DataflowBlock : public BindingBlock {
public:
TVM_DLL explicit DataflowBlock(Array<Binding> bindings, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode);
};

/*! \brief A sequence of blocks followed by an expression.
Expand Down Expand Up @@ -356,6 +363,7 @@ class SeqExpr : public Expr {
public:
TVM_DLL explicit SeqExpr(Array<BindingBlock> blocks, Expr body, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode);
};

/*! \brief A Relax function, eventually to replace the current Relay function definition. */
Expand Down Expand Up @@ -411,6 +419,7 @@ class Function : public Expr {
TVM_DLL explicit Function(runtime::Optional<GlobalVar> name, Array<Var> params, Expr body,
Type ret_type, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
};

/*! \brief The extern function, which can represent packed function. */
Expand Down Expand Up @@ -440,6 +449,7 @@ class ExternFunc : public Expr {
public:
TVM_DLL ExternFunc(String global_symbol, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, Expr, ExternFuncNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode);
};

} // namespace relax
Expand Down
82 changes: 43 additions & 39 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,7 @@ void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
public:
ExprMutator() {
name_table_ = std::make_shared<NameTable>();
builder_ = BlockBuilder(name_table_);
}

/*!
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
Expr Mutate(const Expr& expr) {
return this->VisitExpr(expr);
builder_ = BlockBuilder::Create();
}

Expr VisitExpr(const Expr& expr) override;
Expand Down Expand Up @@ -218,47 +209,60 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
virtual void VisitVarBinding(const VarBinding& binding);
virtual void VisitMatchShape(const MatchShape& binding);

/*!
* \brief Rewrite the var definition site.
* \param var The var to be visited.
* \return The var after post-order rewritten.
* \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
*/
virtual Var VisitVarDef(const Var& var);

virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);

protected:
Expr MutateWithPrologue(const Expr& expr, bool is_dataflow);
class ExprNormalizer;

/*! \brief Look up the value of a variable. If the variable is bound, then returns the bound
* value. Otherwise, returns the rewritten expression for the variable.
/*!
* \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If.
* \param expr The expr to be visited.
* \return The expr after visiting.
*/
Expr LookupVar(Var var);
Expr VisitWithNewScope(const Expr& expr);

inline void UpdateMemo(Expr pre, Expr post) {
if (const VarNode* var = pre.as<VarNode>()) {
var_memo_[var->vid] = post;
} else {
expr_memo_[pre] = post;
}
}
/*!
* \brief Look up the value bound to a variable.
* \param var The var to be looked up.
* \return The value bound to the input \p var.
*/
Expr LookupBinding(const Var& var);

inline Optional<Expr> LookupMemo(Expr pre) {
if (pre.as<VarNode>()) {
Id vid = Downcast<Var>(pre)->vid;
if (var_memo_.count(vid)) {
return var_memo_[vid];
}
} else {
if (expr_memo_.count(pre)) {
return expr_memo_[pre];
}
}
return NullOpt;
/*!
* \brief Post-order rewrite a node and normalize.
* \param T The node type to be rewritten.
* \param op The node to be rewritten.
* \return The node after post rewritten.
*/
template <typename T>
Expr VisitExprPostOrder_(const T* op) {
return builder_->Normalize(ExprMutator::VisitExpr_(op));
}

/*! \brief Variable memoization table using Id equality */
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;

/*! \brief Expr memoization table using pointer equality */
std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> expr_memo_;
/*!
* \brief Create a new var with specified shape and type if it's original shape or type does not
* match with the specified ones.
* \param var The var to be updated.
* \param shape The specified shape.
* \param type The specified type.
* \return The var filled with \p shape and \p type.
*/
Var WithShapeAndType(Var var, Optional<ObjectRef> shape, Type type);

std::shared_ptr<NameTable> name_table_;
/*! \brief Internal block builder to emit bindings during rewriting. */
BlockBuilder builder_;

/*! \brief Remap a var to a new var in use-site. */
std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
};

// TODO(@yuchen, @altan): Refactor to enforce dataflow mutator only rewrite stuff in dataflow blocks
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,7 @@ def vm_shape_lower(mod: IRModule) -> IRModule:
The input module.
"""
return _ffi_api.vm_shape_lower(mod)


def to_anf(mod: IRModule):
return _ffi_api.to_anf(mod)
12 changes: 8 additions & 4 deletions src/printer/relax_script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,9 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowBlockNode* op) {

Doc RelaxScriptPrinter::VisitNode_(const relax::SeqExprNode* op) {
Doc doc;
int i = 0;
for (const relax::BindingBlock& block : op->blocks) {
doc << "# block " << i++ << Doc::NewLine();
doc << Print(block);
}
// NOTE: the body expression is printed in the parent, since SeqExprs are used for both Function
Expand Down Expand Up @@ -484,11 +486,13 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
}
doc << ":" << Doc::NewLine(4);

const relax::SeqExprNode* body = func->body.as<relax::SeqExprNode>();
ICHECK(body) << "in the Relax IR normal form, the body of a Function should be a SeqExpr";
if (const relax::SeqExprNode* body = func->body.as<relax::SeqExprNode>()) {
doc << Doc::Indent(4, Print(func->body));
doc << Doc::Indent(4, Doc::Text("return ") << Print(body->body)) << Doc::NewLine();
} else {
doc << Doc::Indent(4, Doc::Text("return ") << Print(func->body)) << Doc::NewLine();
}

doc << Doc::Indent(4, Print(func->body));
doc << Doc::Indent(4, Doc::Text("return ") << Print(body->body)) << Doc::NewLine();
return doc;
}

Expand Down
Loading

0 comments on commit 35d1ae3

Please sign in to comment.