diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h
new file mode 100644
index 000000000000..596905ae9d90
--- /dev/null
+++ b/include/tvm/relax/backend.h
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/backend.h
+ * \brief Relax backend specific transformation passes.
+ */
+#ifndef TVM_RELAX_BACKEND_H_
+#define TVM_RELAX_BACKEND_H_
+
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+namespace transform {
+
+/*!
+ * \brief Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass VMMemoryLower();
+
+/*!
+ * \brief Lower the shape expressions in Relax to the VM shape heap and TIR shape functions.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass VMShapeLower();
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_BACKEND_H_
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index a20414a7a672..847d21f23c47 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -132,7 +132,7 @@ class Var : public Expr {
   TVM_DLL explicit Var(Id vid, runtime::Optional<Expr> shape_annotation,
                        runtime::Optional<Type> type_annotation, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);
-  TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); 
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
 };
 
 /*! \brief A sub-type of the variable node used to mark dataflow variables from
@@ -414,11 +414,11 @@ class FunctionNode : public BaseFuncNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
 };
 
-class Function : public Expr {
+class Function : public BaseFunc {
  public:
   TVM_DLL explicit Function(runtime::Optional<GlobalVar> name, Array<Var> params, Expr body,
                             Type ret_type, Span span = Span());
-  TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
 };
 
@@ -445,10 +445,10 @@ class ExternFuncNode : public BaseFuncNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncNode, BaseFuncNode);
 };
 
-class ExternFunc : public Expr {
+class ExternFunc : public BaseFunc {
  public:
   TVM_DLL ExternFunc(String global_symbol, Span span = Span());
-  TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, Expr, ExternFuncNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, BaseFunc, ExternFuncNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode);
 };
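Note: rebasing Function and ExternFunc from Expr onto BaseFunc is what lets a single IRModule hold Relax functions next to TIR PrimFuncs. A minimal Python sketch (not part of the patch; hypothetical helper mirroring _split_tir_relax in python/tvm/relax/vm.py below) of the uniform handling this enables:

    import tvm
    from tvm import relax

    def split_functions(mod: tvm.IRModule):
        """Partition a module's functions by kind (hypothetical helper)."""
        rx_funcs, tir_funcs = {}, {}
        for gv, func in mod.functions.items():
            if isinstance(func, relax.Function):
                rx_funcs[gv] = func
            elif isinstance(func, tvm.tir.PrimFunc):
                tir_funcs[gv] = func
        return rx_funcs, tir_funcs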
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
new file mode 100644
index 000000000000..29af6f6361da
--- /dev/null
+++ b/include/tvm/relax/transform.h
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform.h
+ * \brief Relax specific transformation passes.
+ */
+#ifndef TVM_RELAX_TRANSFORM_H_
+#define TVM_RELAX_TRANSFORM_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+namespace transform {
+
+using Pass = tvm::transform::Pass;
+using PassInfo = tvm::transform::PassInfo;
+using PassContext = tvm::transform::PassContext;
+using Function = tvm::relax::Function;
+
+/*!
+ * \brief Create a function pass.
+ *
+ * \param pass_func The packed function that contains the optimization.
+ * \param opt_level The optimization level of the function pass.
+ * \param name The name of the function pass.
+ * \param required The list of the passes that the function pass is dependent on.
+ *
+ * \return The created function pass.
+ */
+TVM_DLL Pass CreateFunctionPass(
+    const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
+    int opt_level, String name, tvm::Array<runtime::String> required);
+
+/*!
+ * \brief Perform fused multiply-add rewriting in dataflow blocks.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass FMARewrite();
+
+/*!
+ * \brief Transform all dataflow structures to non-dataflow versions.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass ToNonDataflow();
+
+/*!
+ * \brief Perform explicit tensor allocation for call_dps.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass CallDPSRewrite();
+
+/*!
+ * \brief Transform Relax IR to A-normal form.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass ToANF();
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_TRANSFORM_H_
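Note: each declaration above is exposed over the FFI as a Python constructor that returns a first-class pass object instead of eagerly rewriting IR. A sketch of applying one (not part of the patch; `mod` is assumed to be an existing Relax IRModule):

    import tvm
    from tvm import relax

    fma = relax.transform.FMARewrite()
    assert isinstance(fma, tvm.transform.Pass)
    new_mod = fma(mod)  # a pass is applied by calling it on an IRModule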
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index e52631314666..8de386dfccbb 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -16,64 +16,71 @@
 # under the License.
 # pylint: disable=no-else-return
 # pylint: disable=unidiomatic-typecheck
-from tvm import IRModule
+import tvm.ir
 from . import _ffi_api
 
 
+@tvm._ffi.register_object("relax.FunctionPass")
+class FunctionPass(tvm.ir.transform.Pass):
+    """A pass that works on each tvm.relax.Function in a module. A function
+    pass class should be created through `function_pass`.
+    """
+
+
-def fma_rewrite(expr):
+def FMARewrite() -> tvm.transform.Pass:
     """Perform fused multiply-add rewriting in dataflow blocks.
 
-    Parameters
-    ----------
-    expr : tvm.relay.Expr
-        The input expression.
+    Returns
+    -------
+    ret: tvm.transform.Pass
     """
-    return _ffi_api.fma_rewrite(expr)
+    return _ffi_api.FMARewrite()
 
 
-def to_non_dataflow(mod: IRModule) -> IRModule:
+def ToNonDataflow() -> tvm.transform.Pass:
     """Transform all dataflow structures to non-dataflow versions.
 
-    Parameters
-    ----------
-    mod : tvm.IRModule
-        The input module.
+    Returns
+    -------
+    ret: tvm.transform.Pass
     """
-    return _ffi_api.to_non_dataflow(mod)
+    return _ffi_api.ToNonDataflow()
 
 
-def call_dps_rewrite(mod: IRModule) -> IRModule:
+def CallDPSRewrite() -> tvm.transform.Pass:
     """Perform explicit tensor allocation for call_dps.
 
-    Parameters
-    ----------
-    mod : tvm.IRModule
-        The input module.
+    Returns
+    -------
+    ret: tvm.transform.Pass
     """
-    return _ffi_api.call_dps_rewrite(mod)
+    return _ffi_api.CallDPSRewrite()
 
 
-def vm_memory_lower(mod: IRModule) -> IRModule:
+def VMMemoryLower() -> tvm.transform.Pass:
     """Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics.
 
-    Parameters
-    ----------
-    mod : tvm.IRModule
-        The input module.
+    Returns
+    -------
+    ret: tvm.transform.Pass
     """
-    return _ffi_api.vm_memory_lower(mod)
+    return _ffi_api.VMMemoryLower()
 
 
-def vm_shape_lower(mod: IRModule) -> IRModule:
-    """Lower the shape expression in relax to VM shape heap and TIR functions.
+def VMShapeLower() -> tvm.transform.Pass:
+    """Lower the shape expressions in relax to VM shape heap manipulations and generate related
+    TIR functions to do shape calculations.
 
-    Parameters
-    ----------
-    mod : tvm.IRModule
-        The input module.
+    Returns
+    -------
+    ret: tvm.transform.Pass
     """
-    return _ffi_api.vm_shape_lower(mod)
+    return _ffi_api.VMShapeLower()
 
 
+def ToANF() -> tvm.transform.Pass:
+    """Transform Relax IR to A-normal form.
 
-def to_anf(mod: IRModule):
-    return _ffi_api.to_anf(mod)
+    Returns
+    -------
+    ret: tvm.transform.Pass
+    """
+    return _ffi_api.ToANF()
diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py
index ffe7b232d309..e4fbff41950a 100644
--- a/python/tvm/relax/vm.py
+++ b/python/tvm/relax/vm.py
@@ -168,10 +168,12 @@ def build(mod: tvm.IRModule,
     lib: tvm.runtime.Module
         A runtime module that contains generated code.
     """
-    new_mod = transform.to_non_dataflow(mod)
-    new_mod = transform.call_dps_rewrite(new_mod)
-    new_mod = transform.vm_memory_lower(new_mod)
-    new_mod = transform.vm_shape_lower(new_mod)
+    passes = [relax.transform.ToNonDataflow()]
+    passes.append(relax.transform.CallDPSRewrite())
+    passes.append(relax.transform.VMMemoryLower())
+    passes.append(relax.transform.VMShapeLower())
+    seq = tvm.transform.Sequential(passes)
+    new_mod = seq(mod)
 
     # split primfunc and relax function
     rx_mod, tir_mod = _split_tir_relax(new_mod)
@@ -189,5 +191,5 @@ def _split_tir_relax(mod: tvm.IRModule) -> Tuple[tvm.IRModule, tvm.IRModule]:
         elif isinstance(mod[gv], relax.Function):
             rx_mod[gv] = mod[gv]
         else:
-            raise ValueError("An IRModule should contain contain relax function and TIR primfunc.")
-    return rx_mod, tir_mod
\ No newline at end of file
+            raise ValueError("An IRModule should contain relax functions and/or TIR primfuncs.")
+    return rx_mod, tir_mod
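Note: for callers migrating from the old eager helpers, the mechanical change is to wrap the expression in an IRModule and apply the pass object, as test_fma_rewrite below does. A sketch (not part of the patch; `expr` is assumed to be a relax.Function):

    import tvm
    from tvm import relax

    # old: func = relax.transform.fma_rewrite(expr)
    mod = tvm.IRModule.from_expr(expr)
    func = relax.transform.FMARewrite()(mod)["main"]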
diff --git a/src/relax/backend/vm/vm_memory_lower.cc b/src/relax/backend/vm/vm_memory_lower.cc
index 269ca141b0ff..b8ac3c47d728 100644
--- a/src/relax/backend/vm/vm_memory_lower.cc
+++ b/src/relax/backend/vm/vm_memory_lower.cc
@@ -18,9 +18,10 @@
  */
 /*!
  * \file src/relax/backend/vm/vm_memory_lower.cc
- * \brief
+ * \brief Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics.
  */
 #include <tvm/relax/attrs/memory.h>
+#include <tvm/relax/backend.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/type.h>
 #include <tvm/tir/op.h>
 
@@ -29,7 +30,6 @@
 namespace tvm {
 namespace relax {
-namespace vm {
 
 // ==================
 // MemLowerMutator
@@ -37,25 +37,11 @@ namespace relax {
 // Example:
 // x = relax.builtin.alloc_tensor((m, n))
 // -->
-// gv0 = relax.call_packed("relax.vm.builtin.alloc_storage", (m * n), relax.attrs.AllocStorageAttrs)
-// gv1 = relax.call_packed("relax.vm.builtin.alloc_tensor", gv0, (m, n), relax.attrs.AllocTensorAttrs)
+// gv0 = relax.call_packed("relax.vm.builtin.alloc_storage", (m * n), relax.attrs.AllocStorageAttrs)
+// gv1 = relax.call_packed("relax.vm.builtin.alloc_tensor", gv0, (m, n),
+// relax.attrs.AllocTensorAttrs)
 
 class VMMemLowerMutator : public ExprMutator {
- public:
-  explicit VMMemLowerMutator(IRModule mod) { mod_ = mod; }
-
-  IRModule Lower() {
-    IRModule ret_mod = IRModule();
-    for (auto& p : mod_->functions) {
-      Expr func = p.second;
-      if (p.second->IsInstance<FunctionNode>()) {
-        func = this->VisitExpr(p.second);
-      }
-      ret_mod->Add(p.first, Downcast<BaseFunc>(func));
-    }
-    return ret_mod;
-  }
-
   Expr ComputeStorageSize(const Expr& shape, const Type& type) const {
     DynTensorType tensor_type = Downcast<DynTensorType>(type);
     DataType dtype = DataType(tensor_type->dtype);
@@ -101,27 +87,33 @@ class VMMemLowerMutator : public ExprMutator {
       storage_attr->dtype = DataType::Float(32);
       storage_attr->device_type = 1;
 
-      Var storage = builder_->Emit(Call(vm_alloc_storage_op, {storage_size}, Attrs(storage_attr)), "storage");
+      Var storage =
+          builder_->Emit(Call(vm_alloc_storage_op, {storage_size}, Attrs(storage_attr)), "storage");
       auto tensor_attr = make_object<AllocTensorAttrs>();
       tensor_attr->offset = 0;
       tensor_attr->dtype = DataType::Float(32);
       Expr shape = call->args[0];
-      Var tensor = builder_->Emit(Call(vm_alloc_tensor_op, {storage, shape}, Attrs(tensor_attr)), "tensor");
+      Var tensor =
+          builder_->Emit(Call(vm_alloc_tensor_op, {storage, shape}, Attrs(tensor_attr)), "tensor");
       return tensor;
     }
     return GetRef<Expr>(call);
   }
-
- private:
-  IRModule mod_;
 };
 
-TVM_REGISTER_GLOBAL("relax.transform.vm_memory_lower")
-.set_body_typed([](IRModule mod) {
-  return VMMemLowerMutator(mod).Lower();
-});
+Expr VMMemLower(const Expr& e) { return VMMemLowerMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass VMMemoryLower() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(VMMemLower(f)); };
+  return CreateFunctionPass(pass_func, 0, "VMMemoryLower", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.VMMemoryLower").set_body_typed(VMMemoryLower);
 
-}  // namespace vm
+}  // namespace transform
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index bbd2ed2c6d07..38ec2e8ae4be 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -18,24 +18,23 @@
  */
 /*!
  * \file src/relax/backend/vm/vm_shape_lower.cc
- * \brief
+ * \brief Lower the shape expressions in relax to VM shape heap manipulations and generate related
+ * TIR functions to do shape calculations.
  */
+#include <tvm/relax/attrs/shape.h>
+#include <tvm/relax/backend.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/type.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/relax/attrs/shape.h>
 
 namespace tvm {
 namespace relax {
-namespace vm {
 
 class VMShapeLowerMutator : public ExprMutator {
  public:
-  static DataType ShapeDType() {
-    return DataType::Int(64);
-  };
+  static DataType ShapeDType() { return DataType::Int(64); };
 
   explicit VMShapeLowerMutator(IRModule mod) { mod_ = mod; }
 
@@ -199,11 +198,16 @@ class VMShapeLowerMutator : public ExprMutator {
   Map<PrimExpr, tir::Var> expr2slot_;
 };
 
-TVM_REGISTER_GLOBAL("relax.transform.vm_shape_lower")
-.set_body_typed([](IRModule mod) {
-  return VMShapeLowerMutator(mod).Lower();
-});
+namespace transform {
+
+Pass VMShapeLower() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) { return VMShapeLowerMutator(mod).Lower(); };
+  return CreateModulePass(pass_func, 0, "VMShapeLower", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed(VMShapeLower);
 
-}  // namespace vm
+}  // namespace transform
 }  // namespace relax
 }  // namespace tvm
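Note: unlike the other rewrites in this patch, VMShapeLower is built with CreateModulePass rather than CreateFunctionPass, because it must add freshly generated TIR shape functions to the module, and a function pass only ever sees one function at a time. For contrast, a generic module pass on the Python side looks like this (illustrative identity pass only, not part of the patch):

    import tvm

    @tvm.transform.module_pass(opt_level=0)
    def identity_pass(mod, ctx):
        # a module pass receives the whole IRModule and may add or remove functions
        return mod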
diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc
new file mode 100644
index 000000000000..69d00bdc340f
--- /dev/null
+++ b/src/relax/ir/transform.cc
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/ir/transform.cc
+ * \brief Relax specific transformation passes.
+ */
+#include <dmlc/thread_local.h>
+#include <tvm/node/repr_printer.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relay/function.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace relax {
+namespace transform {
+
+TVM_REGISTER_PASS_CONFIG_OPTION("relax.fallback_device_type", IntImm);
+
+// TODO(@yuchen): will need to dedup with FunctionPass in Relay when we upstream
+class FunctionPass;
+
+/*!
+ * \brief Function-level passes are used to implement various global
+ * optimizations for a given Relax module. It fetches one function at a time
+ * from the function list in the module for optimization.
+ *
+ * Note that the scope of passes at this level is a Relax function. Therefore,
+ * we cannot add or delete a function through these passes as they are not aware
+ * of the global information.
+ */
+class FunctionPassNode : public tvm::transform::PassNode {
+ public:
+  /* \brief The pass meta data.*/
+  PassInfo pass_info;
+
+  /*! \brief The packed pass function sketches the real optimization. For
+   * instance, we can implement a pass that works on a Relax function as a
+   * `pass_func` and let it run on a given module. The same `pass_func` will
+   * then be applied on each function in the module.
+   */
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func;
+
+  FunctionPassNode() = default;
+
+  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
+
+  /*!
+   * \brief Run a function pass on given pass context.
+   *
+   * \param mod The module that an optimization pass is applied on.
+   * \param pass_ctx The pass context that the optimization pass executes with.
+   *
+   * \return Return the updated module.
+   */
+  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
+
+  /*!
+   * \brief Get the pass information/meta data.
+   */
+  PassInfo Info() const override { return pass_info; }
+
+  static constexpr const char* _type_key = "relax.FunctionPass";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode);
+
+ private:
+  /*
+   * \brief Check if a function should be skipped for optimization.
+   *
+   * \param func The target function to be checked.
+   *
+   * \return Return true if the function will be skipped, otherwise false.
+   */
+  bool SkipFunction(const Function& func) const;
+};
+
+class FunctionPass : public Pass {
+ public:
+  /*!
+   * \brief The constructor
+   * \param pass_func The packed function which implements a pass.
+   * \param pass_info The pass info.
+   */
+  TVM_DLL FunctionPass(
+      runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
+      PassInfo pass_info);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode);
+};
+
+FunctionPass::FunctionPass(
+    runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
+    PassInfo pass_info) {
+  auto n = make_object<FunctionPassNode>();
+  n->pass_func = std::move(pass_func);
+  n->pass_info = std::move(pass_info);
+  data_ = std::move(n);
+}
+
+// Perform Module -> Module optimizations at the Function level.
+IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
+  DiagnosticContext previous = DiagnosticContext::Default(mod);
+
+  if (pass_ctx->diag_ctx) {
+    DiagnosticContext tmp = pass_ctx->diag_ctx.value();
+    pass_ctx->diag_ctx = previous;
+    previous = tmp;
+  } else {
+    pass_ctx->diag_ctx = previous;
+  }
+
+  ICHECK(pass_ctx->diag_ctx)
+      << "The diagnostic context was set at the top of this block; it being unset here is a bug.";
+
+  const PassInfo& pass_info = Info();
+
+  ICHECK(mod.defined());
+
+  VLOG_CONTEXT << pass_info->name;
+  VLOG(0) << "Executing function pass with opt level: " << pass_info->opt_level;
+  VLOG(1) << "Input module:" << std::endl << PrettyPrint(mod);
+
+  IRModule updated_mod = mod->ShallowCopy();
+
+  std::vector<std::pair<GlobalVar, Function> > updates;
+  for (const auto& it : updated_mod->functions) {
+    // only picks up relax::Function
+    if (auto* n = it.second.as<FunctionNode>()) {
+      Function func = GetRef<Function>(n);
+      auto updated_func = SkipFunction(func) ? func : pass_func(func, updated_mod, pass_ctx);
+      updates.push_back({it.first, updated_func});
+    }
+  }
+
+  for (const auto& pair : updates) {
+    updated_mod->Add(pair.first, pair.second, true);
+  }
+
+  ICHECK(pass_ctx->diag_ctx)
+      << "The diagnostic context was set at the top of this block; it being unset here is a bug.";
+
+  pass_ctx->diag_ctx.value().Render();
+  pass_ctx->diag_ctx = previous;
+
+  VLOG(1) << "Output module:" << std::endl << PrettyPrint(updated_mod);
+
+  return updated_mod;
+}
+
+bool FunctionPassNode::SkipFunction(const Function& func) const {
+  // TODO(@yuchen): will need to revisit in the future
+  return (func->GetAttr<String>(relay::attr::kCompiler).defined()) ||
+         func->GetAttr<Integer>(relay::attr::kSkipOptimization, 0) != 0;
+}
+
+Pass CreateFunctionPass(
+    const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
+    int opt_level, String name, tvm::Array<runtime::String> required) {
+  PassInfo pass_info = PassInfo(opt_level, name, required);
+  return FunctionPass(pass_func, pass_info);
+}
+
+TVM_REGISTER_NODE_TYPE(FunctionPassNode);
+
+TVM_REGISTER_GLOBAL("relax._transform.MakeFunctionPass")
+    .set_body_typed(
+        [](runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
+           PassInfo pass_info) { return FunctionPass(pass_func, pass_info); });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<FunctionPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const FunctionPassNode*>(ref.get());
+      const PassInfo info = node->Info();
+      p->stream << "Run Function pass: " << info->name << " at the optimization level "
+                << info->opt_level;
+    });
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
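Note: the ReprPrinter hook above gives every Relax function pass a readable repr. A quick sketch of inspecting a pass object from Python (expected output shown in comments, assuming this patch is built):

    from tvm import relax

    fma = relax.transform.FMARewrite()
    print(fma.info.name, fma.info.opt_level)  # FMARewrite 2
    print(fma)  # Run Function pass: FMARewrite at the optimization level 2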
diff --git a/src/relax/transform/call_dps_rewrite.cc b/src/relax/transform/call_dps_rewrite.cc
index 226b92c0169b..79c7d89fe0b2 100644
--- a/src/relax/transform/call_dps_rewrite.cc
+++ b/src/relax/transform/call_dps_rewrite.cc
@@ -18,10 +18,11 @@
  */
 /*!
  * \file src/relax/transform/call_dps_rewrite.cc
- * \brief
+ * \brief Perform explicit tensor allocation for call_dps.
  */
 #include <tvm/relax/attrs/memory.h>
 #include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
 #include <tvm/relax/type.h>
 #include <tvm/tir/op.h>
 
@@ -41,20 +42,6 @@ namespace relax {
 
 class CallDPSMutator : public ExprMutator {
  public:
-  explicit CallDPSMutator(IRModule mod) { mod_ = mod; }
-
-  IRModule Lower() {
-    IRModule ret_mod = IRModule();
-    for (auto& p : mod_->functions) {
-      Expr func = p.second;
-      if (p.second->IsInstance<FunctionNode>()) {
-        func = this->VisitExpr(p.second);
-      }
-      ret_mod->Add(p.first, Downcast<BaseFunc>(func));
-    }
-    return ret_mod;
-  }
-
   Expr VisitExpr_(const CallNode* call) override {
     // post-order mutation
     Expr expr = VisitExprPostOrder_(call);
@@ -65,21 +52,35 @@ class CallDPSMutator : public ExprMutator {
 
     if (call->op == call_dps_op) {
       ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
-      Var tensor = builder_->Emit(Call(alloc_tensor_op, {output_shape}), "tensor");
-      builder_->Emit(Call(call->args[1], {call->args[2], tensor}), "_");
+      Var tensor = builder_->Emit(Call(alloc_tensor_op, {call->args[0]}), "alloc");
+      Array<Expr> args;
+      if (call->args[2].as<TupleNode>()) {
+        args = Downcast<Tuple>(call->args[2])->fields;
+        args.push_back(tensor);
+        builder_->Emit(Call(call->args[1], args), "_");
+      } else {
+        builder_->Emit(Call(call->args[1], {call->args[2], tensor}), "_");
+      }
       return tensor;
     }
     return GetRef<Expr>(call);
   }
-
- private:
-  IRModule mod_;
 };
 
+Expr CallDPSRewrite(const Expr& e) { return CallDPSMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass CallDPSRewrite() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(CallDPSRewrite(f)); };
+  return CreateFunctionPass(pass_func, 0, "CallDPSRewrite", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.CallDPSRewrite").set_body_typed(CallDPSRewrite);
+
+}  // namespace transform
+
 }  // namespace relax
 }  // namespace tvm
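Note: the new Tuple branch changes how operator arguments reach the packed function: tuple arguments are now splatted and the destination tensor appended, instead of passing the tuple itself as a single argument. Sketched on the example from the tests (pseudo-IR in comments, not part of the patch):

    # gv0 = relax.call_dps((m, n), "test.op.identity", (x,))
    # becomes, after CallDPSRewrite:
    #   alloc = relax.builtin.alloc_tensor((m, n))
    #   _ = test.op.identity(x, alloc)  # tuple args splatted, output appended
    #   gv0 = alloc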
diff --git a/src/relax/transform/fma_rewrite.cc b/src/relax/transform/fma_rewrite.cc
index 15f23a6cb534..546c42e20670 100644
--- a/src/relax/transform/fma_rewrite.cc
+++ b/src/relax/transform/fma_rewrite.cc
@@ -18,9 +18,10 @@
  */
 /*!
  * \file src/relax/transform/fma_rewrite.cc
- * \brief
+ * \brief Perform fused multiply-add rewriting in dataflow blocks.
  */
 #include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
 
 namespace tvm {
 namespace relax {
@@ -33,7 +34,7 @@ namespace relax {
 // -->
 // z0 = ewise_fma(a, b, c)
 
-// Example 2: 
+// Example 2:
 // Question: do we want to support this?
 // x0 = mul(a, add(k, b))
 // z0 = add(x0, c)
@@ -65,14 +66,19 @@ class EwiseFMARewriter : public ExprMutator {
   }
 };
 
-Expr FMARewrite(const Expr& e) {
-  return EwiseFMARewriter().VisitExpr(e);
+Expr FMARewrite(const Expr& e) { return EwiseFMARewriter().VisitExpr(e); }
+
+namespace transform {
+
+Pass FMARewrite() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(FMARewrite(f)); };
+  return CreateFunctionPass(pass_func, 2, "FMARewrite", {});
 }
 
-TVM_REGISTER_GLOBAL("relax.transform.fma_rewrite")
-.set_body_typed([](Expr expr) {
-  return FMARewrite(expr);
-});
+TVM_REGISTER_GLOBAL("relax.transform.FMARewrite").set_body_typed(FMARewrite);
+
+}  // namespace transform
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/transform/to_anf.cc b/src/relax/transform/to_anf.cc
index abcdc2166fff..1d3814fa49c0 100644
--- a/src/relax/transform/to_anf.cc
+++ b/src/relax/transform/to_anf.cc
@@ -24,36 +24,27 @@
 #include <tvm/relax/expr.h>
 #include <tvm/relax/expr_functor.h>
-#include <tvm/runtime/registry.h>
+#include <tvm/relax/transform.h>
 
 namespace tvm {
 namespace relax {
-
 // TODO(@altanh): LCA binding lifting
-class ToANFMutator : public ExprMutator {
- public:
-  ToANFMutator(const IRModule& mod) : mod_(mod) {}
-
-  IRModule Lower() {
-    IRModule ret_mod = IRModule();
-    for (auto& p : mod_->functions) {
-      Expr func = p.second;
-      if (p.second->IsInstance<FunctionNode>()) {
-        func = this->VisitExpr(p.second);
-      }
-      ret_mod->Add(p.first, Downcast<BaseFunc>(func));
-    }
-    return ret_mod;
-  }
-
- private:
-  IRModule mod_;
-};
-
-TVM_REGISTER_GLOBAL("relax.transform.to_anf").set_body_typed([](IRModule mod) {
-  return ToANFMutator(mod).Lower();
-});
+class ToANFMutator : public ExprMutator {};
+
+Expr ToANF(const Expr& e) { return ToANFMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass ToANF() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(ToANF(f)); };
+  return CreateFunctionPass(pass_func, 1, "ToANF", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ToANF").set_body_typed(ToANF);
+
+}  // namespace transform
 }  // namespace relax
 }  // namespace tvm
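Note: the opt_level handed to CreateFunctionPass (2 for FMARewrite, 1 for ToANF, 0 for the lowering passes) matters once passes run under Sequential: the standard TVM pass infra skips any pass whose opt_level exceeds the active PassContext. A sketch of that gating (not part of the patch; assumes a Relax `mod`):

    import tvm
    from tvm import relax

    seq = tvm.transform.Sequential([relax.transform.FMARewrite()])  # opt_level=2
    with tvm.transform.PassContext(opt_level=1):
        out = seq(mod)  # FMARewrite is skipped
    with tvm.transform.PassContext(opt_level=3):
        out = seq(mod)  # FMARewrite runs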
diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc
index 3f58f6e5dc1d..9aad7c6953e3 100644
--- a/src/relax/transform/to_non_dataflow.cc
+++ b/src/relax/transform/to_non_dataflow.cc
@@ -18,10 +18,11 @@
  */
 /*!
  * \file src/relax/transform/to_non_dataflow.cc
- * \brief
+ * \brief Transform all dataflow structures to non-dataflow versions.
  */
 #include <tvm/ir/module.h>
 #include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
 #include <tvm/relax/type.h>
 #include <tvm/tir/op.h>
 
@@ -30,22 +31,8 @@ namespace relax {
 
 class ToNonDFMutator : public ExprMutator {
  public:
-  explicit ToNonDFMutator(IRModule mod) { mod_ = mod; }
-
-  IRModule Lower() {
-    IRModule ret_mod = IRModule();
-    for (auto& p : mod_->functions) {
-      Expr func = p.second;
-      if (p.second->IsInstance<FunctionNode>()) {
-        func = this->VisitExpr(p.second);
-      }
-      ret_mod->Add(p.first, Downcast<BaseFunc>(func));
-    }
-    return ret_mod;
-  }
-
   Var VisitVarDef(const Var& var) final {
-    if (var.as<DataflowVarNode>()){
+    if (var.as<DataflowVarNode>()) {
       Var new_var = Var(var->vid, NullOpt, var->checked_type_, var->span);
       new_var->shape_ = var->shape_;
       this->var_remap_[var->vid] = new_var;
@@ -61,14 +48,21 @@ class ToNonDFMutator : public ExprMutator {
     }
     return builder_->EndBlock();
   }
-
- private:
-  IRModule mod_;
 };
 
-TVM_REGISTER_GLOBAL("relax.transform.to_non_dataflow").set_body_typed([](IRModule mod) {
-  return ToNonDFMutator(mod).Lower();
-});
+Expr ToNonDataflow(const Expr& e) { return ToNonDFMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass ToNonDataflow() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(ToNonDataflow(f)); };
+  return CreateFunctionPass(pass_func, 0, "ToNonDataflow", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ToNonDataflow").set_body_typed(ToNonDataflow);
+
+}  // namespace transform
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index 6375ccae28fb..e6463feeffc1 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -17,9 +17,10 @@
 from __future__ import annotations  # must import to defer parsing of annotations
 import tvm
-from tvm import tir
 from tvm import relax
+from tvm import tir
 from tvm.ir import structural_equal
+from tvm.ir.module import IRModule
 import tvm.script
 from tvm.script import relax as R
 
@@ -39,6 +40,7 @@ def test_fma_rewrite():
         gv0 = ib.emit_output(relax.op.add(lv0, y))
     ib.emit_func_output(gv0)
     expr = ib.get()
+    mod = IRModule.from_expr(expr)
 
     # before rewrite
     v0 = expr.body.blocks[0].bindings[1].var
@@ -50,7 +52,10 @@ def test_fma_rewrite():
     assert structural_equal(gv0.shape, relax.ShapeExpr([m, n]))
 
     # after rewrite
-    func = relax.transform.fma_rewrite(expr)
+    passes = [relax.transform.FMARewrite()]
+    seq = tvm.transform.Sequential(passes)
+    new_mod = seq(mod)
+    func = new_mod["main"]
 
     v1 = func.body.blocks[0].bindings[1].var
     s1 = func.body.blocks[0].bindings[1].value
     assert isinstance(s1, tvm.relay.Call)
@@ -87,7 +92,7 @@ def fvisit(e):
     relax.analysis.post_order_visit(mod["foo"], fvisit)
     _, x, _, gv0, _, gv1 = old_vars
 
-    new_mod = relax.transform.to_non_dataflow(mod)
+    new_mod = relax.transform.ToNonDataflow()(mod)
 
     new_vars = []
 
     def fvisit(e):
@@ -108,13 +113,13 @@ def fvisit(e):
 
 def test_call_dps_rewrite():
     @tvm.script.ir_module
-    class TestCallDpsRewrite:
+    class TestCallDPSRewrite:
         @R.function
         def foo(x: Tensor[(m, n), "float32"]):
             gv0 = relax.call_dps((m, n), "test.op.identity", (x,))
             return gv0
 
-    mod = TestCallDpsRewrite
+    mod = TestCallDPSRewrite
 
     # before rewrite
     v0 = mod["foo"].body.blocks[0].bindings[0].var
@@ -123,7 +128,7 @@ def foo(x: Tensor[(m, n), "float32"]):
     assert s0.op.name == "relax.call_dps"
 
     # after rewrite
-    new_mod = relax.transform.call_dps_rewrite(mod)
+    new_mod = relax.transform.CallDPSRewrite()(mod)
     func = new_mod["foo"]
 
     block = func.body.blocks[0]
@@ -151,7 +156,7 @@ def foo(x: Tensor[(m, n), "float32"]):
     mod = TestVMMemoryLower
 
     # after vm memory lowering
-    new_mod = relax.transform.vm_memory_lower(mod)
+    new_mod = relax.transform.VMMemoryLower()(mod)
     func = new_mod["foo"]
 
     assert isinstance(new_mod, tvm.IRModule)
@@ -181,7 +186,7 @@ def foo(x: Tensor[_, "float32"]) -> Shape:
     mod = TestVMShapeLower
 
     # after vm shape lowering
-    new_mod = relax.transform.vm_shape_lower(mod)
+    new_mod = relax.transform.VMShapeLower()(mod)
 
     assert isinstance(new_mod, tvm.IRModule)
     assert isinstance(new_mod["shape_func"], tvm.tir.function.PrimFunc)
@@ -214,7 +219,7 @@ def test_to_anf():
     func = relax.Function([x], body, None, gvar)
     mod: tvm.IRModule = tvm.IRModule({gvar: func})
 
-    mod = relax.transform.to_anf(mod)
+    new_mod = relax.transform.ToANF()(mod)
 
     @tvm.script.ir_module
     class TestToANFExpected:
@@ -226,7 +231,7 @@ def f(x: Tensor[_, "float32"]):
             return (gv, gv2)
 
     # TODO(@altanh): fix this once type inference works properly...?
-    assert R.parser.astext(mod) == R.parser.astext(TestToANFExpected)
+    assert R.parser.astext(new_mod) == R.parser.astext(TestToANFExpected)
 
@@ -236,3 +241,4 @@ def f(x: Tensor[_, "float32"]):
     test_call_dps_rewrite()
     test_vm_memory_lower()
     test_vm_shape_lowering()
+    test_to_anf()
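Note: with the entry points renamed (fma_rewrite to FMARewrite(), and so on), these tests double as the migration reference for the new pass-object API. They can be run directly:

    pytest tests/python/relax/test_transform.py -v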