diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 74240d750694..c5e2ccd344c5 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -44,26 +44,28 @@ class BlockBuilder; */ class BlockBuilderNode : public Object { public: - BlockBuilderNode(std::shared_ptr name_table) : name_table_(name_table) {} + BlockBuilderNode(); ~BlockBuilderNode(); - BlockBuilderNode() { name_table_ = std::make_shared(); } - /*! \brief Begin to build a DataflowBlock. */ void BeginDataflowBlock(); + /*! \brief Begin to build a BindingBlock. */ void BeginBindingBlock(); + /*! * \brief End building a BindingBlock. * \return The BindingBlock being built. */ BindingBlock EndBlock(); + /*! * \brief Check if the block being built is DataflowBlock or not. * \return A boolean that indicates if the block being built is DataflowBlock or not. */ inline bool CurrentBlockIsDataFlow() { return CurrentFrame()->is_dataflow; } + /*! * \brief Emits an Expr, and returns the variable it is bound to. * \param expr The Expr to be emitted. @@ -71,12 +73,14 @@ class BlockBuilderNode : public Object { * \return The new variable that \p expr is bound to. */ virtual Var Emit(const Expr& expr, std::string name_hint = ""); + /*! * \brief Emits a variable binding, and returns the bound Var. * \param binding The variable binding. * \return The bound variable. */ virtual Var Emit(const VarBinding& binding); + /*! * \brief Emit a MatchShape. * \param value The value of the MatchShape to be emitted. @@ -85,12 +89,14 @@ class BlockBuilderNode : public Object { * \return The variable bound to the MatchShape. */ Var EmitMatchShape(const Expr& value, const Array& pattern, std::string name_hint = ""); + /*! * \brief Emit a MatchShape binding. * \param binding The MatchShape binding to be emitted. * \return The variable bound to the MatchShape. */ Var EmitMatchShape(const MatchShape& binding); + /*! * \brief Generate an output for the current dataflow block. * \param output The output variable of the block. @@ -98,18 +104,21 @@ class BlockBuilderNode : public Object { * \return The variable bound to \p output. */ Var EmitOutput(const Expr& output, std::string name_hint = ""); + /*! * \brief Generate an output for the current dataflow block. * \param binding The output binding to output. * \return The variable bound to \p output. */ Var EmitOutput(const VarBinding& binding); + /*! - * \brief Lookup a var in the binding table \p var_map_. + * \brief Lookup a var in the binding table \p binding_table_. * \param var The input var. * \return The Expr bound to the input \p var. */ - Expr LookupVar(const Var& var); + Expr LookupBinding(const Var& var); + /*! * \brief Check if two shape expressions can be proven equal at compile time. * \param lhs The input lhs shape. @@ -117,17 +126,20 @@ class BlockBuilderNode : public Object { * \return Whether we can prove lhs shape is the same as the rhs shape. */ bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs); + /*! - * \brief Normalize an Expr to complete its shape and type. - * \param expr The input expr. - * \return The expr with normalized shape and type. + * \brief Convert an expression to A-normal form, and try to eagerly infer types and shapes. + * \param expr The input expression. + * \return The normalized expression. */ Expr Normalize(const Expr& expr); + /*! - * \brief Create a BlockBuilder. - * \return The created BlockBuilder. + * \brief Get the name table for generating unique names. + * + * \return The name table. */ - TVM_DLL static BlockBuilder Create(); + NameTable* name_table(); void VisitAttrs(AttrVisitor* v) {} @@ -150,26 +162,45 @@ class BlockBuilderNode : public Object { Array bindings; bool is_dataflow; }; + + /*! + * \brief Utility class for performing IR normalization (conversion to ANF, eager forward shape + * and type inference). + */ + class ExprNormalizer; + friend class BlockBuilder; + /*! * \brief Get the current block frame. * \return The current block frame. */ BlockFrame* CurrentFrame(); + /*! \brief A stack to store block frames. */ std::stack block_stack_; + /*! \brief A diagnostic context for reporting errors. */ DiagnosticContext diag_ctx_ = DiagnosticContext::Default(IRModule({}, {})); + /*! \brief A binding table that maps var to value. */ - // TODO(@yuchen, @altanh): make var_map_ scoped, and decide if it should be in the builder - std::unordered_map var_map_; + std::unordered_map binding_table_; + /*! \brief A name table to get unique names for IR construction. */ - std::shared_ptr name_table_; + std::unique_ptr name_table_; + + /*! \brief The internal normalizer used for ANF conversion. */ + std::unique_ptr normalizer_; }; class BlockBuilder : public ObjectRef { public: - TVM_DLL explicit BlockBuilder(std::shared_ptr name_table); + /*! + * \brief Create a BlockBuilder. + * \return The created BlockBuilder. + */ + TVM_DLL static BlockBuilder Create(); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode); }; diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 23000fa5bbeb..a20414a7a672 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -79,6 +79,7 @@ class ShapeExpr : public Expr { public: TVM_DLL explicit ShapeExpr(Array values, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, Expr, ShapeExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode); }; /*! \brief The variable class for all Relax bindings. */ @@ -131,6 +132,7 @@ class Var : public Expr { TVM_DLL explicit Var(Id vid, runtime::Optional shape_annotation, runtime::Optional type_annotation, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); }; /*! \brief A sub-type of the variable node used to mark dataflow variables from @@ -175,6 +177,7 @@ class DataflowVar : public Var { runtime::Optional type_annotation, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode); }; /*! \brief The base class of a variable binding in Relax. */ @@ -235,6 +238,7 @@ class MatchShape : public Binding { TVM_DLL explicit MatchShape(Expr value, Array pattern, Var var, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchShapeNode); }; class VarBinding; @@ -266,6 +270,7 @@ class VarBinding : public Binding { public: TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode); }; class BindingBlock; @@ -296,6 +301,7 @@ class BindingBlock : public ObjectRef { public: TVM_DLL explicit BindingBlock(Array bindings, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BindingBlockNode); }; class DataflowBlock; @@ -315,6 +321,7 @@ class DataflowBlock : public BindingBlock { public: TVM_DLL explicit DataflowBlock(Array bindings, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode); }; /*! \brief A sequence of blocks followed by an expression. @@ -356,6 +363,7 @@ class SeqExpr : public Expr { public: TVM_DLL explicit SeqExpr(Array blocks, Expr body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode); }; /*! \brief A Relax function, eventually to replace the current Relay function definition. */ @@ -411,6 +419,7 @@ class Function : public Expr { TVM_DLL explicit Function(runtime::Optional name, Array params, Expr body, Type ret_type, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode); }; /*! \brief The extern function, which can represent packed function. */ @@ -440,6 +449,7 @@ class ExternFunc : public Expr { public: TVM_DLL ExternFunc(String global_symbol, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, Expr, ExternFuncNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode); }; } // namespace relax diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 22a6c401f4ff..200867f8785f 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -178,16 +178,7 @@ void PostOrderVisit(const Expr& node, std::function fvisit); class ExprMutator : public ExprFunctor { public: ExprMutator() { - name_table_ = std::make_shared(); - builder_ = BlockBuilder(name_table_); - } - - /*! - * \brief Mutate is alias for VisitExpr - * \return expr. - */ - Expr Mutate(const Expr& expr) { - return this->VisitExpr(expr); + builder_ = BlockBuilder::Create(); } Expr VisitExpr(const Expr& expr) override; @@ -218,47 +209,60 @@ class ExprMutator : public ExprFunctor { virtual void VisitVarBinding(const VarBinding& binding); virtual void VisitMatchShape(const MatchShape& binding); + /*! + * \brief Rewrite the var definition site. + * \param var The var to be visited. + * \return The var after post-order rewritten. + * \note VisitExpr_(const VarNode*) will only visit the usage site of an Var + */ + virtual Var VisitVarDef(const Var& var); + virtual BindingBlock VisitBindingBlock(const BindingBlock& block); virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block); protected: - Expr MutateWithPrologue(const Expr& expr, bool is_dataflow); + class ExprNormalizer; - /*! \brief Look up the value of a variable. If the variable is bound, then returns the bound - * value. Otherwise, returns the rewritten expression for the variable. + /*! + * \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If. + * \param expr The expr to be visited. + * \return The expr after visiting. */ - Expr LookupVar(Var var); + Expr VisitWithNewScope(const Expr& expr); - inline void UpdateMemo(Expr pre, Expr post) { - if (const VarNode* var = pre.as()) { - var_memo_[var->vid] = post; - } else { - expr_memo_[pre] = post; - } - } + /*! + * \brief Look up the value bound to a variable. + * \param var The var to be looked up. + * \return The value bound to the input \p var. + */ + Expr LookupBinding(const Var& var); - inline Optional LookupMemo(Expr pre) { - if (pre.as()) { - Id vid = Downcast(pre)->vid; - if (var_memo_.count(vid)) { - return var_memo_[vid]; - } - } else { - if (expr_memo_.count(pre)) { - return expr_memo_[pre]; - } - } - return NullOpt; + /*! + * \brief Post-order rewrite a node and normalize. + * \param T The node type to be rewritten. + * \param op The node to be rewritten. + * \return The node after post rewritten. + */ + template + Expr VisitExprPostOrder_(const T* op) { + return builder_->Normalize(ExprMutator::VisitExpr_(op)); } - /*! \brief Variable memoization table using Id equality */ - std::unordered_map var_memo_; - - /*! \brief Expr memoization table using pointer equality */ - std::unordered_map expr_memo_; + /*! + * \brief Create a new var with specified shape and type if it's original shape or type does not + * match with the specified ones. + * \param var The var to be updated. + * \param shape The specified shape. + * \param type The specified type. + * \return The var filled with \p shape and \p type. + */ + Var WithShapeAndType(Var var, Optional shape, Type type); - std::shared_ptr name_table_; + /*! \brief Internal block builder to emit bindings during rewriting. */ BlockBuilder builder_; + + /*! \brief Remap a var to a new var in use-site. */ + std::unordered_map var_remap_; }; // TODO(@yuchen, @altan): Refactor to enforce dataflow mutator only rewrite stuff in dataflow blocks diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index aedda8157b54..e52631314666 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -73,3 +73,7 @@ def vm_shape_lower(mod: IRModule) -> IRModule: The input module. """ return _ffi_api.vm_shape_lower(mod) + + +def to_anf(mod: IRModule): + return _ffi_api.to_anf(mod) diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc index 7cdcdebdcbe8..8b29c086d7ca 100644 --- a/src/printer/relax_script_printer.cc +++ b/src/printer/relax_script_printer.cc @@ -309,7 +309,9 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowBlockNode* op) { Doc RelaxScriptPrinter::VisitNode_(const relax::SeqExprNode* op) { Doc doc; + int i = 0; for (const relax::BindingBlock& block : op->blocks) { + doc << "# block " << i++ << Doc::NewLine(); doc << Print(block); } // NOTE: the body expression is printed in the parent, since SeqExprs are used for both Function @@ -484,11 +486,13 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function& } doc << ":" << Doc::NewLine(4); - const relax::SeqExprNode* body = func->body.as(); - ICHECK(body) << "in the Relax IR normal form, the body of a Function should be a SeqExpr"; + if (const relax::SeqExprNode* body = func->body.as()) { + doc << Doc::Indent(4, Print(func->body)); + doc << Doc::Indent(4, Doc::Text("return ") << Print(body->body)) << Doc::NewLine(); + } else { + doc << Doc::Indent(4, Doc::Text("return ") << Print(func->body)) << Doc::NewLine(); + } - doc << Doc::Indent(4, Print(func->body)); - doc << Doc::Indent(4, Doc::Text("return ") << Print(body->body)) << Doc::NewLine(); return doc; } diff --git a/src/relax/backend/vm/vm_memory_lower.cc b/src/relax/backend/vm/vm_memory_lower.cc index c994aee7bc18..269ca141b0ff 100644 --- a/src/relax/backend/vm/vm_memory_lower.cc +++ b/src/relax/backend/vm/vm_memory_lower.cc @@ -49,7 +49,7 @@ class VMMemLowerMutator : public ExprMutator { for (auto& p : mod_->functions) { Expr func = p.second; if (p.second->IsInstance()) { - func = this->Mutate(p.second); + func = this->VisitExpr(p.second); } ret_mod->Add(p.first, Downcast(func)); } @@ -83,7 +83,7 @@ class VMMemLowerMutator : public ExprMutator { Expr VisitExpr_(const CallNode* call) override { // post-order mutation - Expr expr = ExprMutator::VisitExpr_(call); + Expr expr = VisitExprPostOrder_(call); call = expr.as(); static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 25e02785ec27..bbd2ed2c6d07 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -51,7 +51,7 @@ class VMShapeLowerMutator : public ExprMutator { shape_heap_ = Var("shape_heap", ShapeExpr({heap_size_}), heap_type); // mutate - func = this->Mutate(func); + func = this->VisitExpr(func); } ret_mod_->Add(p.first, Downcast(func)); } @@ -78,11 +78,11 @@ class VMShapeLowerMutator : public ExprMutator { return ExprMutator::VisitExpr_(node); } tir::PrimFunc func = CalculateShape(GetRef(node)); - std::string shape_func_name = name_table_->GetUniqueName("shape_func"); + std::string shape_func_name = builder_->name_table()->GetUniqueName("shape_func"); func = WithAttr(std::move(func), "global_symbol", runtime::String(shape_func_name)); GlobalVar shape_func_var(shape_func_name); // TODO make sure shape_heap doesnt get redefined by local funcs? - builder_->Emit(Call(shape_func_var, {shape_heap_}), "_compute_shape"); + builder_->Emit(Call(shape_func_var, {shape_heap_}), "_"); ret_mod_->Add(shape_func_var, func); // construct shape @@ -100,7 +100,7 @@ class VMShapeLowerMutator : public ExprMutator { Expr VisitExpr_(const FunctionNode* node) override { Array params; for (Var param : node->params) { - params.push_back(Downcast(this->Mutate(param))); + params.push_back(this->VisitVarDef(param)); } Type ret_type = this->VisitType(node->ret_type); @@ -108,7 +108,7 @@ class VMShapeLowerMutator : public ExprMutator { builder_->Emit(VarBinding( shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})}))); - Expr new_body = this->Mutate(node->body); + Expr new_body = this->VisitExpr(node->body); Array blocks; diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 6f32123fb2d9..f9f0d9262a1c 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -23,6 +23,7 @@ #include #include +#include #include #include #include @@ -30,19 +31,260 @@ namespace tvm { namespace relax { +// ================================ +// BlockBuilderNode::ExprNormalizer + +// TODO(@altanh): more test cases to cover different visits +class BlockBuilderNode::ExprNormalizer : public ExprFunctor { + public: + ExprNormalizer(BlockBuilderNode* builder) : builder_(builder) {} + +#define RELAX_EXPR_NORMALIZER_LEAF(OP) \ + Expr VisitExpr_(const OP* op) final { return GetRef(op); } + + RELAX_EXPR_NORMALIZER_LEAF(ConstantNode); + RELAX_EXPR_NORMALIZER_LEAF(VarNode); + RELAX_EXPR_NORMALIZER_LEAF(DataflowVarNode); + RELAX_EXPR_NORMALIZER_LEAF(ShapeExprNode); + RELAX_EXPR_NORMALIZER_LEAF(ExternFuncNode); + RELAX_EXPR_NORMALIZER_LEAF(GlobalVarNode); + RELAX_EXPR_NORMALIZER_LEAF(OpNode); + + // TODO(@altanh): CopyOnWrite + + Expr VisitExpr(const Expr& expr) { + Optional post = expr_memo_.Get(expr); + if (post) { + ICHECK(post.as()) << "memoized expressions should map to variables"; + return post.value(); + } + return ExprFunctor::VisitExpr(expr); + } + + Expr VisitExpr_(const TupleNode* op) final { + bool unchanged = true; + Array new_fields; + for (const Expr& field : op->fields) { + Expr new_field = this->Bind(field); + new_fields.push_back(new_field); + unchanged &= new_field.same_as(field); + } + return unchanged ? GetRef(op) : Tuple(new_fields); + } + + Expr VisitExpr_(const FunctionNode* op) final { + Expr new_body = this->VisitWithNewScope(op->body); + if (new_body.same_as(op->body)) { + return GetRef(op); + } + return Function(op->name, op->params, new_body, op->ret_type); + } + + Expr VisitExpr_(const CallNode* op) final { + Expr new_op = this->VisitExpr(op->op); + bool unchanged = new_op.same_as(op->op); + + Array new_args; + for (const Expr& arg : op->args) { + Expr new_arg = this->Bind(arg); + new_args.push_back(new_arg); + unchanged &= new_arg.same_as(arg); + } + + if (unchanged) { + return GetRef(op); + } + return Call(new_op, new_args, op->attrs, op->type_args); + } + + Expr VisitExpr_(const SeqExprNode* op) final { + bool unchanged = true; + Array new_blocks; + for (const BindingBlock& block : op->blocks) { + // TODO(@altanh): we could merge sequential non-dataflow BindingBlocks here + BindingBlock new_block = this->VisitBindingBlock(block); + new_blocks.push_back(new_block); + unchanged &= new_block.same_as(block); + } + + builder_->BeginBindingBlock(); + Expr new_body = this->VisitExpr(op->body); + unchanged &= new_body.same_as(op->body); + BindingBlock prologue = builder_->EndBlock(); + + // TODO(@altanh, @yuchen): normalize nested SeqExprs and BindingBlocks + + if (!prologue->bindings.empty()) { + new_blocks.push_back(prologue); + unchanged = false; + } + + if (unchanged) { + return GetRef(op); + } + return SeqExpr(new_blocks, new_body); + } + + Expr VisitExpr_(const IfNode* op) final { + Expr new_cond = this->VisitExpr(op->cond); + Expr new_true = this->VisitWithNewScope(op->true_branch); + Expr new_false = this->VisitWithNewScope(op->false_branch); + if (new_cond.same_as(op->cond) && new_true.same_as(op->true_branch) && + new_false.same_as(op->false_branch)) { + return GetRef(op); + } + return If(new_cond, new_true, new_false); + } + + Expr VisitExpr_(const TupleGetItemNode* op) final { + Expr new_tuple = this->VisitExpr(op->tuple); + if (new_tuple.same_as(op->tuple)) { + return GetRef(op); + } + return TupleGetItem(new_tuple, op->index); + } + + Binding VisitBinding(const Binding& binding) { + if (binding.as()) { + return this->VisitVarBinding(Downcast(binding)); + } else { + ICHECK(binding.as()) << "expected VarBinding or MatchShape, got " << binding; + return this->VisitMatchShape(Downcast(binding)); + } + } + + VarBinding VisitVarBinding(const VarBinding& binding) { + Expr new_value = this->VisitExpr(binding->value); + if (new_value.same_as(binding->value) || new_value.same_as(binding->var)) { + // if new_value = binding->var, then we have found an ANF binding site, so just return it + return binding; + } + return VarBinding(binding->var, new_value); + } + + MatchShape VisitMatchShape(const MatchShape& binding) { + Expr new_value = this->VisitExpr(binding->value); + if (new_value.same_as(binding->value)) { + return binding; + } + return MatchShape(new_value, binding->pattern, binding->var); + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) { + if (block.as()) { + builder_->BeginDataflowBlock(); + } else { + builder_->BeginBindingBlock(); + } + + bool unchanged = true; + for (const Binding& binding : block->bindings) { + Binding new_binding = this->VisitBinding(binding); + unchanged &= new_binding.same_as(binding); + if (new_binding.as()) { + VarBinding var_binding = Downcast(new_binding); + if (builder_->CurrentBlockIsDataFlow() && !var_binding->var.as()) { + builder_->EmitOutput(var_binding); + } else { + builder_->Emit(var_binding); + } + } else { + ICHECK(new_binding.as()); + builder_->EmitMatchShape(Downcast(new_binding)); + } + } + BindingBlock new_block = builder_->EndBlock(); + unchanged &= new_block->bindings.size() == block->bindings.size(); + if (unchanged) { + return block; + } + return new_block; + } + + private: + /*! + * \brief Memoization map for expressions using Id for equality of variables. + */ + class ExprMemo { + public: + Optional Get(const Expr& expr) { + if (const VarNode* var = expr.as()) { + auto it = var_memo_.find(var->vid); + if (it != var_memo_.end()) { + return it->second; + } + } else { + auto it = expr_memo_.find(expr); + if (it != expr_memo_.end()) { + return it->second; + } + } + return NullOpt; + } + + void Set(const Expr& pre, const Expr& post) { + if (const VarNode* var = pre.as()) { + var_memo_[var->vid] = post; + } else { + expr_memo_[pre] = post; + } + } + + private: + std::unordered_map var_memo_; + std::unordered_map expr_memo_; + }; + + static bool IsLeaf(const Expr& expr) { + // NB: tuples are treated as leaf nodes for ergonomics + // TODO(@altanh, @yuchen): remove TupleNode from leaf + return expr.as() || expr.as() || expr.as() || + expr.as() || expr.as() || expr.as() || + expr.as(); + } + + Expr VisitWithNewScope(const Expr& expr) { + builder_->BeginBindingBlock(); + Expr post = this->VisitExpr(expr); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + post = SeqExpr({prologue}, post); + } + return post; + } + + Expr Bind(const Expr& expr) { + Expr post = this->VisitExpr(expr); + if (!IsLeaf(post)) { + post = builder_->Emit(post); + expr_memo_.Set(expr, post); + } + return post; + } + + /*! \brief BlockBuilder used for emitting intermediate variables. */ + BlockBuilderNode* builder_; + + /*! \brief Memoization table for mapping expressions to their ANF variables. */ + ExprMemo expr_memo_; +}; + +// ================ +// BlockBuilderNode + TVM_REGISTER_NODE_TYPE(BlockBuilderNode); +BlockBuilderNode::BlockBuilderNode() { + name_table_ = std::make_unique(); + normalizer_ = std::make_unique(this); +} + BlockBuilderNode::~BlockBuilderNode() { if (!block_stack_.empty()) { LOG(WARNING) << "BlockBuilder destroyed with remaining blocks!"; } } -BlockBuilder BlockBuilderNode::Create() { - BlockBuilder ret(make_object()); - return ret; -} - void BlockBuilderNode::BeginDataflowBlock() { this->block_stack_.push({{}, true}); } void BlockBuilderNode::BeginBindingBlock() { this->block_stack_.push({{}, false}); } @@ -106,7 +348,7 @@ Var BlockBuilderNode::Emit(const Expr& expr, bool is_dataflow, std::string name_ new_call->shape_ = inferred_shape; cur_frame->bindings.push_back(VarBinding(var, new_call)); - this->var_map_[var->vid] = new_call; + this->binding_table_[var->vid] = new_call; } else if (const VarNode* var_node = expr.as()) { const Var& lhs_var = GetRef(var_node); if (lhs_var->shape_.defined()) { @@ -116,12 +358,10 @@ Var BlockBuilderNode::Emit(const Expr& expr, bool is_dataflow, std::string name_ var->checked_type_ = lhs_var->checked_type_; } cur_frame->bindings.push_back(VarBinding(var, lhs_var)); - this->var_map_[var->vid] = lhs_var; - } - - else { + this->binding_table_[var->vid] = lhs_var; + } else { cur_frame->bindings.push_back(VarBinding(var, expr)); - this->var_map_[var->vid] = expr; + binding_table_[var->vid] = expr; } return var; @@ -133,7 +373,7 @@ Var BlockBuilderNode::Emit(const VarBinding& binding) { ICHECK(binding->var.as()); } cur_frame->bindings.push_back(binding); - this->var_map_[binding->var->vid] = binding->value; + binding_table_[binding->var->vid] = binding->value; return binding->var; } @@ -168,11 +408,16 @@ Var BlockBuilderNode::EmitMatchShape(const Expr& value, const Array& p Var BlockBuilderNode::EmitMatchShape(const MatchShape& binding) { BlockFrame* cur_frame = CurrentFrame(); - if (cur_frame->is_dataflow) { + if (cur_frame->is_dataflow && binding->var.defined()) { ICHECK(!binding->var.as()) << "cannot bind DataflowVar outside dataflow block."; } cur_frame->bindings.push_back(binding); + // TODO(@altanh, @yuchen): what value should we bind? Consider + // y = add(x, x) + // z = match_shape(y, (n, m)) + // We would like pass writers to match "z" with the "add" node but with extra shape info. + // Maybe this logic could be deferred to a DFPattern-style rewriter? return binding->var; } @@ -191,13 +436,13 @@ Var BlockBuilderNode::EmitOutput(const VarBinding& binding) { ICHECK(!binding->var.as()) << "EmitOutput can only emit Var bindings."; cur_frame->bindings.push_back(binding); - this->var_map_[binding->var->vid] = binding->value; + binding_table_[binding->var->vid] = binding->value; return binding->var; } -Expr BlockBuilderNode::LookupVar(const Var& var) { - auto it = this->var_map_.find(var->vid); - if (it == this->var_map_.end()) { +Expr BlockBuilderNode::LookupBinding(const Var& var) { + auto it = binding_table_.find(var->vid); + if (it == binding_table_.end()) { this->diag_ctx_.EmitFatal(Diagnostic::Error(var->span) << "The var to be looked up is not in the binding table."); } @@ -229,23 +474,36 @@ bool BlockBuilderNode::CanProveShapeEqual(const Expr& lhs, const Expr& rhs) { return false; } -// TODO(@altanh, @yuchen): emit expr in ssa form +// TODO(@altanh, @yuchen): need an internal Emit_ that doesn't call normalize Expr BlockBuilderNode::Normalize(const Expr& expr) { - if (expr.as()) { - Call call = Downcast(expr); + // TODO(@altanh): fast path + Expr normalized = normalizer_->VisitExpr(expr); + if (normalized.as()) { + // FIXME(@altanh): potentially breaks idempotency + Call call = Downcast(normalized); + + // only do shape/type inference if the call does not have shape/type + if (call->shape_ && call->checked_type_.defined()) { + return call; + } + // Shape inference - auto inferred_shape = InferShape(call, this->diag_ctx_); - if (inferred_shape.defined()) { - if (auto* shape_expr = inferred_shape.value().as()) { - call->shape_ = GetRef(shape_expr); + if (!call->shape_) { + auto inferred_shape = InferShape(call, this->diag_ctx_); + if (inferred_shape) { + call->shape_ = this->Normalize(inferred_shape.value()); } } - // Type inference - auto inferred_type = InferType(call, this->diag_ctx_); - call->checked_type_ = inferred_type; + + if (!call->checked_type_.defined()) { + // Type inference + auto inferred_type = InferType(call, this->diag_ctx_); + call->checked_type_ = inferred_type; + } + return call; } - return expr; + return normalized; } BlockBuilderNode::BlockFrame* BlockBuilderNode::CurrentFrame() { @@ -253,13 +511,15 @@ BlockBuilderNode::BlockFrame* BlockBuilderNode::CurrentFrame() { return &block_stack_.top(); } -BlockBuilder::BlockBuilder(std::shared_ptr name_table) { - ObjectPtr n = make_object(); - n->name_table_ = name_table; - data_ = std::move(n); +NameTable* BlockBuilderNode::name_table() { + return name_table_.get(); +} + +BlockBuilder BlockBuilder::Create() { + return BlockBuilder(make_object()); } -TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed(BlockBuilderNode::Create); +TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed(BlockBuilder::Create); TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginDataflowBlock") .set_body_typed([](BlockBuilder builder) { builder->BeginDataflowBlock(); }); diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 7efa104c36bd..be80b38eb28d 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -180,19 +180,15 @@ TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit").set_body_typed([](Expr ex // ================== // ExprMutator -Expr ExprMutator::VisitExpr_(const ConstantNode* op) { - return GetRef(op); -} +Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { - return GetRef(op); -} +Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef(op); } Expr ExprMutator::VisitExpr_(const TupleNode* op) { tvm::Array fields; bool all_fields_unchanged = true; for (Expr field : op->fields) { - Expr new_field = this->Mutate(field); + Expr new_field = this->VisitExpr(field); fields.push_back(new_field); all_fields_unchanged &= new_field.same_as(field); } @@ -204,24 +200,24 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) { } } +// Visit the use-site of a defined Var Expr ExprMutator::VisitExpr_(const VarNode* op) { - if (op->type_annotation.defined()) { - Type type = this->VisitType(op->type_annotation.value()); - if (!op->type_annotation.same_as(type)) { - return Var(op->vid, Downcast(op->shape()), type, op->span); - } + auto it = var_remap_.find(op->vid); + if (it != var_remap_.end()) { + return it->second; } + // default case return self. return GetRef(op); } +// Visit the use-site of a defined DataflowVar Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) { - if (op->type_annotation.defined()) { - Type type = this->VisitType(op->type_annotation.value()); - if (!op->type_annotation.same_as(type)) { - return DataflowVar(op->vid, Downcast(op->shape()), type, op->span); - } + auto it = var_remap_.find(op->vid); + if (it != var_remap_.end()) { + return it->second; } + // default case return self. return GetRef(op); } @@ -230,13 +226,13 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { tvm::Array params; bool all_params_unchanged = true; for (Var param : op->params) { - Var new_param = Downcast(this->Mutate(param)); + Var new_param = this->VisitVarDef(param); params.push_back(new_param); all_params_unchanged &= param.same_as(new_param); } Type ret_type = this->VisitType(op->ret_type); - Expr body = this->MutateWithPrologue(op->body, false); + Expr body = this->VisitWithNewScope(op->body); if (all_params_unchanged && ret_type.same_as(op->ret_type) && body.same_as(op->body)) { return GetRef(op); @@ -246,7 +242,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { } Expr ExprMutator::VisitExpr_(const CallNode* call_node) { - Expr new_op = this->Mutate(call_node->op); + Expr new_op = this->VisitExpr(call_node->op); bool unchanged = call_node->op.same_as(new_op); tvm::Array ty_args; @@ -258,7 +254,7 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) { tvm::Array call_args; for (Expr arg : call_node->args) { - Expr new_arg = this->Mutate(arg); + Expr new_arg = this->VisitExpr(arg); call_args.push_back(new_arg); unchanged &= new_arg.same_as(arg); } @@ -271,9 +267,9 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) { } Expr ExprMutator::VisitExpr_(const IfNode* op) { - Expr guard = this->Mutate(op->cond); - Expr true_b = this->MutateWithPrologue(op->true_branch, false); - Expr false_b = this->MutateWithPrologue(op->false_branch, false); + Expr guard = this->VisitExpr(op->cond); + Expr true_b = this->VisitWithNewScope(op->true_branch); + Expr false_b = this->VisitWithNewScope(op->false_branch); if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b)) { return GetRef(op); @@ -285,21 +281,17 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) { Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef(op); } Expr ExprMutator::VisitExpr_(const TupleGetItemNode* get_item) { - auto t = this->Mutate(get_item->tuple); - if (get_item->tuple == t) { + auto t = this->VisitExpr(get_item->tuple); + if (get_item->tuple.same_as(t)) { return GetRef(get_item); } else { return TupleGetItem(t, get_item->index, get_item->span); } } -Expr ExprMutator::VisitExpr_(const ShapeExprNode* op) { - return GetRef(op); -} +Expr ExprMutator::VisitExpr_(const ShapeExprNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const ExternFuncNode* op) { - return GetRef(op); -} +Expr ExprMutator::VisitExpr_(const ExternFuncNode* op) { return GetRef(op); } Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { bool all_blocks_unchanged = true; @@ -313,7 +305,7 @@ Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { } builder_->BeginBindingBlock(); - Expr body = this->Mutate(op->body); + Expr body = this->VisitExpr(op->body); BindingBlock prologue = builder_->EndBlock(); if (!prologue->bindings.empty()) { blocks.push_back(prologue); @@ -340,42 +332,21 @@ void ExprMutator::VisitBinding(const Binding& binding) { } void ExprMutator::VisitVarBinding(const VarBinding& binding) { - Expr new_value = builder_->Normalize(this->Mutate(binding->value)); - - // TODO(@altanh): this probably shouldn't live here, all passes would have to make sure to do it - // in this method... - // if (new_value->shape_.defined()) { - // if (new_var->shape_.defined()) { - // new_var = Var(new_var->vid, NullOpt, new_var->type_annotation, new_var->span); - // } - // new_var->shape_ = new_value->shape_; - // } - // if (new_value->checked_type_.defined()) { - // if (new_var->checked_type_.defined()) { - - // } - // new_var = Var(new_var->vid, new_var->shape_, NullOpt, new_var->span); - // new_var->checked_type_ = new_value->checked_type_; - // } - - Var new_var = Downcast(this->Mutate(binding->var)); - if (!builder_->CanProveShapeEqual(new_var->shape(), new_value->shape()) || - !StructuralEqual()(new_var->checked_type(), new_value->checked_type())) { - // TODO(@altanh): use CopyOnWrite and/or type inference machinery here - if (new_var.as()) { - new_var = DataflowVar(new_var->vid, NullOpt, NullOpt, new_var->span); - } else { - new_var = Var(new_var->vid, NullOpt, NullOpt, new_var->span); - } - if (new_value->shape_.defined()) { - new_var->shape_ = new_value->shape_; - } - // TODO(@yuchen, @altanh): checked_type_.defined() needs to change depends on how to represent unknown type - if (new_value->checked_type_.defined()){ - new_var->checked_type_ = new_value->checked_type_; - } + Expr new_value = this->VisitExpr(binding->value); + Var new_var = this->VisitVarDef(binding->var); - UpdateMemo(binding->var, new_var); + if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + // no-op if there is no change + builder_->Emit(binding); + return; + } + + { + Var temp = WithShapeAndType(new_var, new_value->shape_, new_value->checked_type_); + if (!temp.same_as(new_var)) { + new_var = temp; + this->var_remap_[binding->var->vid] = new_var; + } } if (builder_->CurrentBlockIsDataFlow() && !new_var.as()) { @@ -386,17 +357,28 @@ void ExprMutator::VisitVarBinding(const VarBinding& binding) { } void ExprMutator::VisitMatchShape(const MatchShape& binding) { - Expr new_value = this->Mutate(binding->value); - Expr new_pattern = this->Mutate(ShapeExpr(binding->pattern)); + Expr new_value = this->VisitExpr(binding->value); + Expr new_pattern = this->VisitExpr(ShapeExpr(binding->pattern)); + Var new_var; - if (binding->var.defined()){ - new_var = Downcast(this->Mutate(binding->var)); - // TODO(@altanh, @yuchen): shape and type inference here too... - } else { - new_var = binding->var; + if (binding->var.defined()) { + // in the case of `x = R.match_shape(val, pattern)`, we want `x` to directly get `pattern` as + // the shape when `val` is a tensor. + Optional new_shape; + if (new_value->checked_type_.defined() && new_value->checked_type_.as()) { + new_shape = new_pattern; + } + Var temp = + WithShapeAndType(this->VisitVarDef(binding->var), new_shape, new_value->checked_type_); + if (!temp.same_as(new_var)) { + new_var = temp; + this->var_remap_[binding->var->vid] = new_var; + } } + // TODO(@altanh, @yuchen): shape and type inference here too... // TODO: when value's shape/type changed, create new var + // TODO: group the can prove shape/type logic and replace var into a function builder_->EmitMatchShape( MatchShape(new_value, Downcast(new_pattern)->values, new_var)); } @@ -421,25 +403,31 @@ BindingBlock ExprMutator::VisitDataflowBlock(const DataflowBlock& block) { return builder_->EndBlock(); } -Expr ExprMutator::VisitExpr(const Expr& expr) { - Optional post = LookupMemo(expr); - if (post) { - return post.value(); +Var ExprMutator::VisitVarDef(const Var& var) { + if (var->type_annotation.defined()) { + Type type = this->VisitType(var->type_annotation.value()); + if (!var->type_annotation.same_as(type)) { + Var new_var; + if (var.as()) { + new_var = DataflowVar(var->vid, NullOpt, type, var->span); + } else { + new_var = Var(var->vid, NullOpt, type, var->span); + } + new_var->shape_ = var->shape_; + this->var_remap_[var->vid] = new_var; + return new_var; + } } - - UpdateMemo(expr, ExprFunctor::VisitExpr(expr)); - - return LookupMemo(expr).value(); + return var; } -Expr ExprMutator::MutateWithPrologue(const Expr& expr, bool is_dataflow) { - if (is_dataflow) { - builder_->BeginDataflowBlock(); - } else { - builder_->BeginBindingBlock(); - } +Expr ExprMutator::VisitExpr(const Expr& expr) { + return builder_->Normalize(ExprFunctor::VisitExpr(expr)); +} - Expr ret = this->Mutate(expr); +Expr ExprMutator::VisitWithNewScope(const Expr& expr) { + builder_->BeginBindingBlock(); + Expr ret = this->VisitExpr(expr); BindingBlock prologue = builder_->EndBlock(); if (!prologue->bindings.empty()) { ret = SeqExpr({prologue}, ret); @@ -447,19 +435,36 @@ Expr ExprMutator::MutateWithPrologue(const Expr& expr, bool is_dataflow) { return ret; } -Expr ExprMutator::LookupVar(Var var) { - // cases: - // 1. var has been rewritten to some expr (e.g. a constant) and is no longer bound - // 2. var remains bound to some expr - // 3. var is deleted, in which case this should never be called - Expr mutated_var = LookupMemo(var).value(); - if (mutated_var.as()) { - // lookup bound var in the builder - return builder_->LookupVar(Downcast(mutated_var)); +Expr ExprMutator::LookupBinding(const Var& var) { return builder_->LookupBinding(var); } + +Var ExprMutator::WithShapeAndType(Var var, Optional shape, Type type) { + // shape/type changes if it goes from defined -> undefined or the other way, hence xor + bool shape_changed = var->shape_.operator bool() ^ shape.operator bool(); + shape_changed |= var->shape_ && shape && + !builder_->CanProveShapeEqual(Downcast(var->shape_.value()), + Downcast(shape.value())); + + bool type_changed = var->checked_type_.defined() ^ type.defined(); + type_changed |= var->checked_type_.defined() && type.defined() && + !StructuralEqual()(var->checked_type_, type); + + if (shape_changed || type_changed) { + Var new_var = var.as() ? DataflowVar(var->vid, NullOpt, NullOpt, var->span) + : Var(var->vid, NullOpt, NullOpt, var->span); + new_var->shape_ = var->shape_; + new_var->checked_type_ = var->checked_type_; + var = new_var; + } + + if (shape_changed) { + var->shape_ = shape; + } + + if (type_changed) { + var->checked_type_ = type; } - // return the rewritten var value - return mutated_var; + return var; } // ================== diff --git a/src/relax/transform/call_dps_rewrite.cc b/src/relax/transform/call_dps_rewrite.cc index 76ef84589817..226b92c0169b 100644 --- a/src/relax/transform/call_dps_rewrite.cc +++ b/src/relax/transform/call_dps_rewrite.cc @@ -48,7 +48,7 @@ class CallDPSMutator : public ExprMutator { for (auto& p : mod_->functions) { Expr func = p.second; if (p.second->IsInstance()) { - func = this->Mutate(p.second); + func = this->VisitExpr(p.second); } ret_mod->Add(p.first, Downcast(func)); } @@ -57,7 +57,7 @@ class CallDPSMutator : public ExprMutator { Expr VisitExpr_(const CallNode* call) override { // post-order mutation - Expr expr = ExprMutator::VisitExpr_(call); + Expr expr = VisitExprPostOrder_(call); call = expr.as(); static const Op& call_dps_op = Op::Get("relax.call_dps"); @@ -65,7 +65,7 @@ class CallDPSMutator : public ExprMutator { if (call->op == call_dps_op) { ShapeExpr output_shape = Downcast(call->args[0]); - Var tensor = builder_->Emit(Call(alloc_tensor_op, {call->args[0]}), "alloc"); + Var tensor = builder_->Emit(Call(alloc_tensor_op, {output_shape}), "tensor"); builder_->Emit(Call(call->args[1], {call->args[2], tensor}), "_"); return tensor; } diff --git a/src/relax/transform/fma_rewrite.cc b/src/relax/transform/fma_rewrite.cc index 8108832ff068..15f23a6cb534 100644 --- a/src/relax/transform/fma_rewrite.cc +++ b/src/relax/transform/fma_rewrite.cc @@ -43,7 +43,7 @@ namespace relax { class EwiseFMARewriter : public ExprMutator { Expr VisitExpr_(const CallNode* call) override { - Expr expr = ExprMutator::VisitExpr_(call); + Expr expr = VisitExprPostOrder_(call); call = expr.as(); static const Op& add_op = Op::Get("relax.add"); @@ -52,7 +52,8 @@ class EwiseFMARewriter : public ExprMutator { if (call->op == add_op) { // NOTE: assumes df block is completely SSA - Expr value = LookupVar(Downcast(call->args[0])); + // FIXME(@altanh, @yuchen): this will crash if args[0] isn't a Var + Expr value = LookupBinding(Downcast(call->args[0])); const CallNode* mul = value.as(); if (mul && mul->op == multiply_op) { Call fma_call = Call(ewise_fma_op, {mul->args[0], mul->args[1], call->args[1]}, {}, {}); @@ -65,7 +66,7 @@ class EwiseFMARewriter : public ExprMutator { }; Expr FMARewrite(const Expr& e) { - return EwiseFMARewriter().Mutate(e); + return EwiseFMARewriter().VisitExpr(e); } TVM_REGISTER_GLOBAL("relax.transform.fma_rewrite") diff --git a/src/relax/transform/to_anf.cc b/src/relax/transform/to_anf.cc new file mode 100644 index 000000000000..abcdc2166fff --- /dev/null +++ b/src/relax/transform/to_anf.cc @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/to_anf.cc + * \brief Pass for transforming Relax IR to A-normal form. + */ + +#include +#include +#include + +namespace tvm { +namespace relax { + + +// TODO(@altanh): LCA binding lifting +class ToANFMutator : public ExprMutator { + public: + ToANFMutator(const IRModule& mod) : mod_(mod) {} + + IRModule Lower() { + IRModule ret_mod = IRModule(); + for (auto& p : mod_->functions) { + Expr func = p.second; + if (p.second->IsInstance()) { + func = this->VisitExpr(p.second); + } + ret_mod->Add(p.first, Downcast(func)); + } + return ret_mod; + } + + private: + IRModule mod_; +}; + +TVM_REGISTER_GLOBAL("relax.transform.to_anf").set_body_typed([](IRModule mod) { + return ToANFMutator(mod).Lower(); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc index 63fea270dca6..3f58f6e5dc1d 100644 --- a/src/relax/transform/to_non_dataflow.cc +++ b/src/relax/transform/to_non_dataflow.cc @@ -37,15 +37,21 @@ class ToNonDFMutator : public ExprMutator { for (auto& p : mod_->functions) { Expr func = p.second; if (p.second->IsInstance()) { - func = this->Mutate(p.second); + func = this->VisitExpr(p.second); } ret_mod->Add(p.first, Downcast(func)); } return ret_mod; } - Expr VisitExpr_(const DataflowVarNode* op) final { - return Var(op->vid, op->shape(), op->type_annotation, op->span); + Var VisitVarDef(const Var& var) final { + if (var.as()){ + Var new_var = Var(var->vid, NullOpt, var->checked_type_, var->span); + new_var->shape_ = var->shape_; + this->var_remap_[var->vid] = new_var; + return new_var; + } + return var; } BindingBlock VisitDataflowBlock(const DataflowBlock& block) final { diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index fa6a9e3ee6e4..6375ccae28fb 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -103,7 +103,7 @@ def fvisit(e): assert isinstance(gv1, relax.Var) assert isinstance(new_vars[5], relax.Var) - assert gv1 != new_vars[5] + assert gv1 == new_vars[5] def test_call_dps_rewrite(): @@ -204,6 +204,32 @@ def foo(x: Tensor[_, "float32"]) -> Shape: assert isinstance(s5, tvm.relay.Call) assert s5.op.name == "relax.vm.builtin.load_shape" +def test_to_anf(): + x = relax.Var("x", type_annotation=relax.DynTensorType()) + gv = relax.op.add(x, x) + gv1 = relax.op.add(gv, gv) + gv2 = relax.op.add(gv, gv1) + body = relax.Tuple([gv, gv2]) + gvar = relax.GlobalVar("f") + func = relax.Function([x], body, None, gvar) + + mod: tvm.IRModule = tvm.IRModule({gvar: func}) + mod = relax.transform.to_anf(mod) + + @tvm.script.ir_module + class TestToANFExpected: + @R.function + def f(x: Tensor[_, "float32"]): + gv = relax.add(x, x) + gv1 = relax.add(gv, gv) + gv2 = relax.add(gv, gv1) + return (gv, gv2) + + # TODO(@altanh): fix this once type inference works properly...? + assert R.parser.astext(mod) == R.parser.astext(TestToANFExpected) + + + if __name__ == "__main__": test_fma_rewrite() test_to_non_dataflow()