diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 200867f8785f..15c5617dfc01 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -136,13 +136,19 @@ class ExprFunctor { } }; + /*! * \brief A simple visitor wrapper around ExprFunctor. * Recursively visit the content. */ -class ExprVisitor : public ExprFunctor { +class ExprVisitor : public ExprFunctor { public: + /*! + * \brief Generic dispatcher for Expr. + * \param expr The expr to be visited. + */ void VisitExpr(const Expr& expr) override; + // specific leaf level visitor functions void VisitExpr_(const ConstantNode* op) override; void VisitExpr_(const TupleNode* op) override; void VisitExpr_(const VarNode* op) override; @@ -157,13 +163,36 @@ class ExprVisitor : public ExprFunctor { void VisitExpr_(const OpNode* op) override; void VisitExpr_(const TupleGetItemNode* op) override; - virtual void VisitType(const Type& t); - virtual void VisitSpan(const Span& span); + /*! + * \brief Generic dispatcher for bindings. + * \param binding The binding to be visited. + */ virtual void VisitBinding(const Binding& binding); - virtual void VisitVarBinding(const VarBinding& binding); - virtual void VisitMatchShape(const MatchShape& binding); + // specific leaf level visitor functions + virtual void VisitBinding_(const VarBindingNode* binding); + virtual void VisitBinding_(const MatchShapeNode* binding); + + /*! + * \brief Generic dispatcher for binding blocks. + * \param block The binding block to be visited. + */ virtual void VisitBindingBlock(const BindingBlock& block); - virtual void VisitDataflowBlock(const DataflowBlock& block); + // specific leaf level visitor functions + virtual void VisitBindingBlock_(const BindingBlockNode* block); + virtual void VisitBindingBlock_(const DataflowBlockNode* block); + + /*! + * \brief Generic dispatcher for visiting the var definition site. + * \param var The var to be visited. + * \note VisitExpr_(const VarNode*) will only visit the usage site of an Var + */ + virtual void VisitVarDef(const Var& var); + // specific leaf level visitor functions + virtual void VisitVarDef_(const VarNode* var); + virtual void VisitVarDef_(const DataflowVarNode* var); + + virtual void VisitType(const Type& t); + virtual void VisitSpan(const Span& span); }; void PostOrderVisit(const Expr& node, std::function fvisit); @@ -205,20 +234,35 @@ class ExprMutator : public ExprFunctor { */ virtual Type VisitType(const Type& t); + /*! + * \brief Generic dispatcher for bindings. + * \param binding The binding to be visited. + */ virtual void VisitBinding(const Binding& binding); - virtual void VisitVarBinding(const VarBinding& binding); - virtual void VisitMatchShape(const MatchShape& binding); + // specific leaf level visitor functions + virtual void VisitBinding_(const VarBindingNode* binding); + virtual void VisitBinding_(const MatchShapeNode* binding); + + /*! + * \brief Generic dispatcher for binding blocks. + * \param block The binding block to be visited. + * \return The binding block after transformation. + */ + virtual BindingBlock VisitBindingBlock(const BindingBlock& block); + // specific leaf level visitor functions + virtual BindingBlock VisitBindingBlock_(const BindingBlockNode* block); + virtual BindingBlock VisitBindingBlock_(const DataflowBlockNode* block); /*! - * \brief Rewrite the var definition site. + * \brief Generic dispatcher for rewriting the var definition site. * \param var The var to be visited. * \return The var after post-order rewritten. * \note VisitExpr_(const VarNode*) will only visit the usage site of an Var */ virtual Var VisitVarDef(const Var& var); - - virtual BindingBlock VisitBindingBlock(const BindingBlock& block); - virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block); + // specific leaf level visitor functions + virtual Var VisitVarDef_(const VarNode* var); + virtual Var VisitVarDef_(const DataflowVarNode* var); protected: class ExprNormalizer; @@ -265,16 +309,6 @@ class ExprMutator : public ExprFunctor { std::unordered_map var_remap_; }; -// TODO(@yuchen, @altan): Refactor to enforce dataflow mutator only rewrite stuff in dataflow blocks -/*! \brief Dataflow Graph Rewriting for Custom Rewriting Passes - */ -class DataflowMutator : public ExprMutator { - public: - void VisitBinding(const Binding& binding) final; - - virtual void VisitDataflowVarBinding(const VarBinding& binding); -}; - } // namespace relax } // namespace tvm #endif // TVM_RELAX_EXPR_FUNCTOR_H_ diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 38ec2e8ae4be..14ed499180b8 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -57,7 +57,7 @@ class VMShapeLowerMutator : public ExprMutator { return ret_mod_; } - void VisitMatchShape(const MatchShape& binding) override { + void VisitBinding_(const MatchShapeNode* binding) override { Expr shape = ExprMutator::VisitExpr(binding->value); static const Op& store_shape_op = Op::Get("relax.vm.builtin.store_shape"); auto store_shape_attr = make_object(); diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index be80b38eb28d..bb1a1c58d96c 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -61,7 +61,7 @@ void ExprVisitor::VisitExpr_(const DataflowVarNode* op) { void ExprVisitor::VisitExpr_(const FunctionNode* op) { this->VisitSpan(op->span); for (Var param : op->params) { - this->VisitExpr(param); + this->VisitVarDef(param); } this->VisitExpr(op->body); @@ -110,50 +110,79 @@ void ExprVisitor::VisitType(const Type& t) {} void ExprVisitor::VisitSpan(const Span& span) {} -void ExprVisitor::VisitBinding(const Binding& binding) { - if (binding.as()) { - this->VisitVarBinding(Downcast(binding)); - } else if (binding.as()) { - this->VisitMatchShape(Downcast(binding)); - } else { - LOG(FATAL) << "Wrong type."; - } -} - -void ExprVisitor::VisitVarBinding(const VarBinding& binding) { +void ExprVisitor::VisitBinding_(const VarBindingNode* binding) { this->VisitExpr(binding->value); - this->VisitExpr(binding->var); + this->VisitVarDef(binding->var); } -void ExprVisitor::VisitMatchShape(const MatchShape& binding) { +void ExprVisitor::VisitBinding_(const MatchShapeNode* binding) { this->VisitExpr(binding->value); // TODO(ziheng): should we change pattern from // Array to ShapeExpr? this->VisitExpr(ShapeExpr(binding->pattern)); if (binding->var.defined()) { - this->VisitExpr(binding->var); + this->VisitVarDef(binding->var); } } -void ExprVisitor::VisitBindingBlock(const BindingBlock& block) { - if (block.as()) { - this->VisitDataflowBlock(Downcast(block)); - } else { - for (Binding binding : block->bindings) { - this->VisitBinding(binding); - } +void ExprVisitor::VisitBindingBlock_(const BindingBlockNode* block) { + for (Binding binding : block->bindings) { + this->VisitBinding(binding); } } -void ExprVisitor::VisitDataflowBlock(const DataflowBlock& block) { +void ExprVisitor::VisitBindingBlock_(const DataflowBlockNode* block) { for (Binding binding : block->bindings) { this->VisitBinding(binding); } } +void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) { + this->VisitSpan(var->span); + if (var->type_annotation.defined()) { + this->VisitType(var->type_annotation.value()); + } +} + +void ExprVisitor::VisitVarDef_(const VarNode* var) { + this->VisitSpan(var->span); + if (var->type_annotation.defined()) { + this->VisitType(var->type_annotation.value()); + } +} + void ExprVisitor::VisitExpr(const Expr& expr) { - using TParent = ExprFunctor; - TParent::VisitExpr(expr); + ExprFunctor::VisitExpr(expr); +} + +void ExprVisitor::VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + VisitBinding_(node); + } else if (const auto* node = binding.as()) { + VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } +} + +void ExprVisitor::VisitBindingBlock(const BindingBlock& block) { + if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } +} + +void ExprVisitor::VisitVarDef(const Var& var) { + if (const auto* node = var.as()) { + VisitVarDef_(node); + } else if (const auto* node = var.as()) { + VisitVarDef_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } } class ExprApplyVisit : public ExprVisitor { @@ -321,23 +350,13 @@ Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { Type ExprMutator::VisitType(const Type& t) { return t; } -void ExprMutator::VisitBinding(const Binding& binding) { - if (binding.as()) { - this->VisitVarBinding(Downcast(binding)); - } else if (binding.as()) { - this->VisitMatchShape(Downcast(binding)); - } else { - LOG(FATAL) << "Wrong type."; - } -} - -void ExprMutator::VisitVarBinding(const VarBinding& binding) { +void ExprMutator::VisitBinding_(const VarBindingNode* binding) { Expr new_value = this->VisitExpr(binding->value); Var new_var = this->VisitVarDef(binding->var); if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { // no-op if there is no change - builder_->Emit(binding); + builder_->Emit(GetRef(binding)); return; } @@ -356,7 +375,7 @@ void ExprMutator::VisitVarBinding(const VarBinding& binding) { } } -void ExprMutator::VisitMatchShape(const MatchShape& binding) { +void ExprMutator::VisitBinding_(const MatchShapeNode* binding) { Expr new_value = this->VisitExpr(binding->value); Expr new_pattern = this->VisitExpr(ShapeExpr(binding->pattern)); @@ -383,19 +402,15 @@ void ExprMutator::VisitMatchShape(const MatchShape& binding) { MatchShape(new_value, Downcast(new_pattern)->values, new_var)); } -BindingBlock ExprMutator::VisitBindingBlock(const BindingBlock& block) { - if (block.as()) { - return this->VisitDataflowBlock(Downcast(block)); - } else { - builder_->BeginBindingBlock(); - for (Binding binding : block->bindings) { - this->VisitBinding(binding); - } - return builder_->EndBlock(); +BindingBlock ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); } + return builder_->EndBlock(); } -BindingBlock ExprMutator::VisitDataflowBlock(const DataflowBlock& block) { +BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) { builder_->BeginDataflowBlock(); for (auto binding : block->bindings) { this->VisitBinding(binding); @@ -403,28 +418,70 @@ BindingBlock ExprMutator::VisitDataflowBlock(const DataflowBlock& block) { return builder_->EndBlock(); } -Var ExprMutator::VisitVarDef(const Var& var) { +Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) { if (var->type_annotation.defined()) { Type type = this->VisitType(var->type_annotation.value()); if (!var->type_annotation.same_as(type)) { - Var new_var; - if (var.as()) { - new_var = DataflowVar(var->vid, NullOpt, type, var->span); - } else { - new_var = Var(var->vid, NullOpt, type, var->span); - } + Var new_var = DataflowVar(var->vid, NullOpt, type, var->span); new_var->shape_ = var->shape_; this->var_remap_[var->vid] = new_var; return new_var; } } - return var; + return GetRef(var); +} + +Var ExprMutator::VisitVarDef_(const VarNode* var) { + if (var->type_annotation.defined()) { + Type type = this->VisitType(var->type_annotation.value()); + if (!var->type_annotation.same_as(type)) { + Var new_var = Var(var->vid, NullOpt, type, var->span); + new_var->shape_ = var->shape_; + this->var_remap_[var->vid] = new_var; + return new_var; + } + } + return GetRef(var); } Expr ExprMutator::VisitExpr(const Expr& expr) { return builder_->Normalize(ExprFunctor::VisitExpr(expr)); } +void ExprMutator::VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + VisitBinding_(node); + } else if (const auto* node = binding.as()) { + VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } +} + +BindingBlock ExprMutator::VisitBindingBlock(const BindingBlock& block) { + BindingBlock ret; + if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + return ret; +} + +Var ExprMutator::VisitVarDef(const Var& var) { + Var ret; + if (const auto* node = var.as()) { + ret = VisitVarDef_(node); + } else if (const auto* node = var.as()) { + ret = VisitVarDef_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } + return ret; +} + Expr ExprMutator::VisitWithNewScope(const Expr& expr) { builder_->BeginBindingBlock(); Expr ret = this->VisitExpr(expr); @@ -467,25 +524,5 @@ Var ExprMutator::WithShapeAndType(Var var, Optional shape, Type type) return var; } -// ================== -// DataflowMutator - -void DataflowMutator::VisitBinding(const Binding& binding) { - if (binding.as()) { - VarBinding var_binding = Downcast(binding); - if (builder_->CurrentBlockIsDataFlow()) { - this->VisitDataflowVarBinding(var_binding); - } else { - ExprMutator::VisitVarBinding(var_binding); - } - } else { - ExprMutator::VisitBinding(binding); - } -} - -void DataflowMutator::VisitDataflowVarBinding(const VarBinding& binding) { - ExprMutator::VisitVarBinding(binding); -} - } // namespace relax } // namespace tvm diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc index 9aad7c6953e3..f489d2338b49 100644 --- a/src/relax/transform/to_non_dataflow.cc +++ b/src/relax/transform/to_non_dataflow.cc @@ -41,7 +41,7 @@ class ToNonDFMutator : public ExprMutator { return var; } - BindingBlock VisitDataflowBlock(const DataflowBlock& block) final { + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final { builder_->BeginBindingBlock(); for (Binding binding : block->bindings) { this->VisitBinding(binding); diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index e6463feeffc1..ad6ee597e31f 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -75,10 +75,10 @@ class TestToNonDataflow: @R.function def foo(x: Tensor[(m, n), "float32"]): with relax.dataflow(): - gv0 = relax.call_dps((m, n), "test.op.identity", (x,)) - gv1 = relax.call_dps((m, n), "test.op.identity", (gv0,)) - relax.output(gv1) - return gv1 + lv0 = relax.call_dps((m, n), "test.op.identity", (x,)) + gv0 = relax.call_dps((m, n), "test.op.identity", (lv0,)) + relax.output(gv0) + return gv0 mod = TestToNonDataflow @@ -90,7 +90,7 @@ def fvisit(e): old_vars.append(e) relax.analysis.post_order_visit(mod["foo"], fvisit) - _, x, _, gv0, _, gv1 = old_vars + x, lv0, gv0 = old_vars new_mod = relax.transform.ToNonDataflow()(mod) @@ -101,14 +101,14 @@ def fvisit(e): new_vars.append(e) relax.analysis.post_order_visit(new_mod["foo"], fvisit) - assert x == new_vars[1] - assert gv0 != new_vars[3] - assert isinstance(gv0, relax.DataflowVar) - assert not isinstance(new_vars[3], relax.DataflowVar) + assert x == new_vars[0] + assert lv0 != new_vars[1] + assert isinstance(lv0, relax.DataflowVar) + assert not isinstance(new_vars[1], relax.DataflowVar) - assert isinstance(gv1, relax.Var) - assert isinstance(new_vars[5], relax.Var) - assert gv1 == new_vars[5] + assert isinstance(gv0, relax.Var) + assert isinstance(new_vars[2], relax.Var) + assert gv0 == new_vars[2] def test_call_dps_rewrite():