diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 847d21f23c47..44f8e7ea9856 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -35,6 +35,7 @@ using Expr = RelayExpr;
 using ExprNode = RelayExprNode;
 using relay::Call;
 using relay::CallNode;
+using relay::Constant;
 using relay::ConstantNode;
 using relay::Id;
 using relay::If;
diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h
index 15c5617dfc01..c2e6df3aae0a 100644
--- a/include/tvm/relax/expr_functor.h
+++ b/include/tvm/relax/expr_functor.h
@@ -293,7 +293,7 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
   }
 
   /*!
-   * \brief Create a new var with specified shape and type if it's original shape or type does not
+   * \brief Create a new var with specified shape and type if the original var's shape or type does not
    * match with the specified ones.
   * \param var The var to be updated.
   * \param shape The specified shape.
diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index d921d3a137e6..cc3d3769da0a 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -32,38 +32,6 @@
 namespace tvm {
 namespace relax {
 
-/*!
- * \brief Visitor to apply a function to every Expr it visits. Also applies the function
- * to the shape field of the var definition site if the var's shape is a ShapeExpr.
- */
-class ExprApplyVisitWithShape : public ExprVisitor {
- public:
-  explicit ExprApplyVisitWithShape(std::function<void(const Expr&)> f) : f_(f) {}
-
-  void VisitVarDef(const Var& var) {
-    if (var.as<DataflowVarNode>()) {
-      this->VisitExpr(Downcast<DataflowVar>(var));
-    } else {
-      this->VisitExpr(var);
-    }
-    if (var->shape_.operator bool() && var->shape_.value().as<ShapeExprNode>()) {
-      f_(Downcast<ShapeExpr>(var->shape_.value()));
-    }
-  }
-
-  void VisitExpr(const Expr& e) final {
-    ExprVisitor::VisitExpr(e);
-    f_(e);
-  }
-
- private:
-  std::function<void(const Expr&)> f_;
-};
-
-void PostOrderVisitWithShape(const Expr& e, std::function<void(const Expr&)> fvisit) {
-  ExprApplyVisitWithShape(fvisit).VisitExpr(e);
-}
-
 class VMShapeLowerMutator : public ExprMutator {
  public:
   static DataType ShapeDType() { return DataType::Int(64); };
@@ -125,9 +93,7 @@ class VMShapeLowerMutator : public ExprMutator {
     builder_->BeginBindingBlock();
     builder_->Emit(VarBinding(
         shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})})));
-    Array<Var> params;
     for (Var param : node->params) {
-      params.push_back(this->VisitVarDef(param));
       if (param->shape_.operator bool() && param->shape_.value().as<ShapeExprNode>()) {
         Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {param}), "sh");
         StoreShape(shape, Downcast<ShapeExpr>(param->shape_.value())->values);
@@ -150,7 +116,7 @@ class VMShapeLowerMutator : public ExprMutator {
     blocks.push_back(builder_->EndBlock());
     new_body = SeqExpr(blocks, new_body);
 
-    return Function(node->name, params, new_body, ret_type);
+    return Function(node->name, node->params, new_body, ret_type);
   }
 
   tir::PrimFunc CalculateShape(ShapeExpr s) {
@@ -201,7 +167,7 @@ class VMShapeLowerMutator : public ExprMutator {
         }
       }
     };
-    PostOrderVisitWithShape(expr, func);
+    PostOrderVisit(expr, func);
     return ret;
   }
 
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index 12ff2413b8fe..65b0cab47197 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -33,7 +33,13 @@
 namespace tvm {
 namespace relax {
 
-void ExprVisitor::VisitExpr_(const ConstantNode* op) { this->VisitSpan(op->span); }
+void ExprVisitor::VisitExpr_(const ConstantNode* op) {
+  this->VisitSpan(op->span);
+
+  if (op->shape_) {
+    this->VisitExpr(Downcast<Expr>(op->shape_.value()));
+  }
+}
 
 void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); }
 
@@ -42,20 +48,20 @@ void ExprVisitor::VisitExpr_(const TupleNode* op) {
   for (Expr field : op->fields) {
     this->VisitExpr(field);
   }
+
+  if (op->shape_) {
+    this->VisitExpr(Downcast<Expr>(op->shape_.value()));
+  }
 }
 
+// Visit the use-site of a defined Var
 void ExprVisitor::VisitExpr_(const VarNode* op) {
   this->VisitSpan(op->span);
-  if (op->type_annotation.defined()) {
-    this->VisitType(op->type_annotation.value());
-  }
 }
 
+// Visit the use-site of a defined DataflowVar
 void ExprVisitor::VisitExpr_(const DataflowVarNode* op) {
   this->VisitSpan(op->span);
-  if (op->type_annotation.defined()) {
-    this->VisitType(op->type_annotation.value());
-  }
 }
 
 void ExprVisitor::VisitExpr_(const FunctionNode* op) {
@@ -78,6 +84,10 @@ void ExprVisitor::VisitExpr_(const CallNode* op) {
   for (Expr arg : op->args) {
     this->VisitExpr(arg);
   }
+
+  if (op->shape_) {
+    this->VisitExpr(Downcast<Expr>(op->shape_.value()));
+  }
 }
 
 void ExprVisitor::VisitExpr_(const IfNode* op) {
@@ -142,6 +152,10 @@ void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) {
   if (var->type_annotation.defined()) {
     this->VisitType(var->type_annotation.value());
   }
+
+  if (var->shape_) {
+    this->VisitExpr(Downcast<Expr>(var->shape_.value()));
+  }
 }
 
 void ExprVisitor::VisitVarDef_(const VarNode* var) {
@@ -149,12 +163,14 @@
   if (var->type_annotation.defined()) {
     this->VisitType(var->type_annotation.value());
   }
-}
 
-void ExprVisitor::VisitExpr(const Expr& expr) {
-  ExprFunctor::VisitExpr(expr);
+  if (var->shape_) {
+    this->VisitExpr(Downcast<Expr>(var->shape_.value()));
+  }
 }
 
+void ExprVisitor::VisitExpr(const Expr& expr) { ExprFunctor::VisitExpr(expr); }
+
 void ExprVisitor::VisitBinding(const Binding& binding) {
   if (const auto* node = binding.as<VarBindingNode>()) {
     VisitBinding_(node);
@@ -209,23 +225,48 @@ TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit").set_body_typed([](Expr ex
 // ==================
 // ExprMutator
 
-Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef<Expr>(op); }
+Expr ExprMutator::VisitExpr_(const ConstantNode* op) {
+  Expr new_shape;
+  bool unchanged = true;
+  if (op->shape_) {
+    new_shape = this->VisitExpr(Downcast<Expr>(op->shape_.value()));
+    if (!new_shape.same_as(op->shape_)) {
+      unchanged = false;
+    }
+  }
+
+  if (unchanged) {
+    return GetRef<Expr>(op);
+  } else {
+    Expr new_constant = Constant(op->data, op->span);
+    new_constant->shape_ = new_shape;
+    return new_constant;
+  }
+}
 
 Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef<Expr>(op); }
 
 Expr ExprMutator::VisitExpr_(const TupleNode* op) {
+  bool unchanged = true;
   tvm::Array<Expr> fields;
-  bool all_fields_unchanged = true;
   for (Expr field : op->fields) {
     Expr new_field = this->VisitExpr(field);
     fields.push_back(new_field);
-    all_fields_unchanged &= new_field.same_as(field);
+    unchanged &= new_field.same_as(field);
+  }
+
+  Expr new_shape;
+  if (op->shape_) {
+    new_shape = this->VisitExpr(Downcast<Expr>(op->shape_.value()));
+    unchanged &= new_shape.same_as(op->shape_);
   }
 
-  if (all_fields_unchanged) {
+  if (unchanged) {
     return GetRef<Expr>(op);
   } else {
-    return Tuple(fields, op->span);
+    Expr new_tuple = Tuple(fields, op->span);
+    new_tuple->shape_ = new_shape;
+    return new_tuple;
   }
 }
 
@@ -288,10 +329,18 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
     unchanged &= new_arg.same_as(arg);
   }
 
+  Expr new_shape;
+  if (call_node->shape_) {
+    new_shape = this->VisitExpr(Downcast<Expr>(call_node->shape_.value()));
+    unchanged &= new_shape.same_as(call_node->shape_);
+  }
+
   if (unchanged) {
     return GetRef<Expr>(call_node);
   } else {
-    return Call(new_op, call_args, call_node->attrs, ty_args, call_node->span);
+    Expr new_call = Call(new_op, call_args, call_node->attrs, ty_args, call_node->span);
+    new_call->shape_ = new_shape;
+    return new_call;
   }
 }
 
@@ -424,29 +473,75 @@ BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) {
 }
 
 Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) {
+  bool type_unchanged = true;
+  Type new_type;
   if (var->type_annotation.defined()) {
-    Type type = this->VisitType(var->type_annotation.value());
-    if (!var->type_annotation.same_as(type)) {
-      Var new_var = DataflowVar(var->vid, NullOpt, type, var->span);
+    new_type = this->VisitType(var->type_annotation.value());
+    type_unchanged &= new_type.same_as(var->type_annotation);
+  }
+
+  bool shape_unchanged = true;
+  Expr new_shape;
+  if (var->shape_) {
+    new_shape = this->VisitExpr(Downcast<Expr>(var->shape_.value()));
+    shape_unchanged &= new_shape.same_as(var->shape_);
+  }
+
+  if (type_unchanged && shape_unchanged) {
+    return GetRef<DataflowVar>(var);
+  } else {
+    Var new_var;
+    if (type_unchanged) {
+      new_var = DataflowVar(var->vid, NullOpt, var->type_annotation, var->span);
+    } else {
+      new_var = DataflowVar(var->vid, NullOpt, new_type, var->span);
+    }
+
+    if (shape_unchanged) {
       new_var->shape_ = var->shape_;
-      this->var_remap_[var->vid] = new_var;
-      return new_var;
+    } else {
+      new_var->shape_ = new_shape;
     }
+
+    this->var_remap_[var->vid] = new_var;
+    return new_var;
   }
-  return GetRef<DataflowVar>(var);
 }
 
 Var ExprMutator::VisitVarDef_(const VarNode* var) {
+  bool type_unchanged = true;
+  Type new_type;
   if (var->type_annotation.defined()) {
-    Type type = this->VisitType(var->type_annotation.value());
-    if (!var->type_annotation.same_as(type)) {
-      Var new_var = Var(var->vid, NullOpt, type, var->span);
+    new_type = this->VisitType(var->type_annotation.value());
+    type_unchanged &= new_type.same_as(var->type_annotation);
+  }
+
+  bool shape_unchanged = true;
+  Expr new_shape;
+  if (var->shape_) {
+    new_shape = this->VisitExpr(Downcast<Expr>(var->shape_.value()));
+    shape_unchanged &= new_shape.same_as(var->shape_);
+  }
+
+  if (type_unchanged && shape_unchanged) {
+    return GetRef<Var>(var);
+  } else {
+    Var new_var;
+    if (type_unchanged) {
+      new_var = Var(var->vid, NullOpt, var->type_annotation, var->span);
+    } else {
+      new_var = Var(var->vid, NullOpt, new_type, var->span);
+    }
+
+    if (shape_unchanged) {
       new_var->shape_ = var->shape_;
-      this->var_remap_[var->vid] = new_var;
-      return new_var;
+    } else {
+      new_var->shape_ = new_shape;
     }
+
+    this->var_remap_[var->vid] = new_var;
+    return new_var;
   }
-  return GetRef<Var>(var);
 }
 
 Expr ExprMutator::VisitExpr(const Expr& expr) {
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index abd6fdb0f99e..ddec2fe929f8 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -53,9 +53,7 @@ def test_fma_rewrite():
     assert structural_equal(gv0.shape, relax.ShapeExpr([m, n]))
 
     # after rewrite
-    passes = [relax.transform.FMARewrite()]
-    seq = tvm.transform.Sequential(passes)
-    new_mod = seq(mod)
+    new_mod = relax.transform.FMARewrite()(mod)
     func = new_mod["main"]
     v1 = func.body.blocks[0].bindings[1].var
     s1 = func.body.blocks[0].bindings[1].value
@@ -69,6 +67,31 @@ def test_fma_rewrite():
     assert gv0 == v0
     assert type(func.body.blocks[0].bindings[1].var) == relax.Var
 
 
+def test_visit_shape():
+    @tvm.script.ir_module
+    class TestVisitShape:
+        @R.function
+        def foo(x: Tensor[(m, n), "float32"]):
+            gv0 = R.add(x, x)
+            return gv0
+
+    mod = TestVisitShape
+
+    shape_expr = []
+    def fvisit(e):
+        if isinstance(e, relax.ShapeExpr):
+            nonlocal shape_expr
+            shape_expr.append(e)
+
+    relax.analysis.post_order_visit(mod["foo"], fvisit)
+
+    # ShapeExpr should be visited 3 times:
+    # the first visit is x's shape;
+    # the last two are the call node's shape and gv0's shape
+    assert len(shape_expr) == 3
+    assert shape_expr[0] == mod["foo"].params[0].shape
+    assert shape_expr[1] == shape_expr[2]
+
 
 def test_to_non_dataflow():
@@ -312,6 +335,7 @@ def foo(x: Tensor[(m, n), "float32"]):
 
 if __name__ == "__main__":
     test_fma_rewrite()
+    test_visit_shape()
     test_to_non_dataflow()
     test_call_dps_rewrite()
     test_vm_memory_lower()