From fb0d7265c0e0e11614fd9783b70d219f336a450d Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 3 Aug 2020 09:36:41 -0700 Subject: [PATCH] [FoldConstant] Create Interpreter for each constant subgraph (#6195) --- src/relay/transforms/fold_constant.cc | 36 +++++++++++++++------------ 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index b077a8ae206b..0ecbfea8c905 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -77,9 +77,8 @@ TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantChec // or make a more powerful partial evaluator. class ConstantFolder : public ExprMutator { public: - explicit ConstantFolder(FInterpreter executor, IRModule module) - : executor_(executor), - module_(module), + explicit ConstantFolder(IRModule module) + : module_(module), shape_of_op_(Op::Get("shape_of")), vm_shape_of_op_(Op::Get("vm.shape_of")), invoke_tvm_op_(Op::Get("vm.invoke_tvm_op")), @@ -163,8 +162,6 @@ class ConstantFolder : public ExprMutator { } private: - // Internal interepreter. - FInterpreter executor_; // Internal constant checker ConstantChecker checker_; // Module @@ -180,6 +177,20 @@ class ConstantFolder : public ExprMutator { const Op& cast_op_; const Op& ndarray_size_op_; + // Create an interpreter. + FInterpreter GetInterpreter(const IRModule& mod) { + using tvm::transform::PassContext; + DLContext ctx; + ctx.device_type = kDLCPU; + ctx.device_id = 0; + Target target = Target::Create("llvm"); + // use a fresh build context + // in case we are already in a build context. + With fresh_build_ctx(PassContext::Create()); + + return CreateInterpreter(mod, ctx, target); + } + // Convert value to expression. Expr ObjectToExpr(const ObjectRef& value) { if (value->IsInstance()) { @@ -218,7 +229,9 @@ class ConstantFolder : public ExprMutator { mod = seq(mod); auto entry_func = Downcast(mod->Lookup("main")); expr = expr.as() == nullptr ? entry_func->body : entry_func; - return ObjectToExpr(executor_(expr)); + + FInterpreter executor = GetInterpreter(mod); + return ObjectToExpr(executor(expr)); } // Evaluate a call to the shape_of operator for tensors with constant @@ -331,16 +344,7 @@ class ConstantFolder : public ExprMutator { }; Expr FoldConstant(const Expr& expr, const IRModule& mod) { - using tvm::transform::PassContext; - DLContext ctx; - ctx.device_type = kDLCPU; - ctx.device_id = 0; - Target target = Target::Create("llvm"); - // use a fresh build context - // in case we are already in a build context. - With fresh_build_ctx(PassContext::Create()); - - return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr); + return ConstantFolder(mod).Mutate(expr); } namespace transform {