diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index cb7761b0b385..73b0d22804bd 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -139,7 +139,7 @@ def codegen(self): """Generate the kernel library.""" self._codegen() - def optimize(self, mod, target=None, params=None): + def optimize(self, mod, target=None, target_host=None, params=None): """Helper method that optimizes a Relay module via VM. Parameters @@ -149,6 +149,11 @@ def optimize(self, mod, target=None, params=None): target : str, :any:`tvm.target.Target`, or dict of str (i.e. device/context name) to str/tvm.target.Target, optional + target_host : str or :any:`tvm.target.Target`, optional + The compilation target for host. + By default, llvm is used if it is enabled, + otherwise a stackvm interpreter is used. + params : dict of str to NDArray Input parameters to the graph that do not change during inference time. Used for constant folding. @@ -162,9 +167,10 @@ def optimize(self, mod, target=None, params=None): The parameters of the final module. """ target = self._update_target(target) + target_host = self._update_target_host(target, target_host) if params: self.set_params(params) - return self._optimize(mod, target), self.get_params() + return self._optimize(mod, target, target_host), self.get_params() def get_exec(self): """Get the VM executable. 
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index a98f1ef9d0ab..33854f783d45 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -806,8 +806,8 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr<Obje }); } else if (name == "optimize") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 2); - *rv = this->OptimizeModule(args[0], args[1]); + CHECK_EQ(args.num_args, 3); + *rv = this->OptimizeModule(args[0], args[1], args[2]); }); } else { LOG(FATAL) << "Unknown packed function: " << name; @@ -835,7 +835,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe target_host_ = target_host; // Run the optimizations necessary to target the VM. - context_.module = OptimizeModule(mod, targets_); + context_.module = OptimizeModule(mod, targets_, target_host_); // Populate the global map. // @@ -923,7 +923,8 @@ transform::Sequential MemoryOpt(tvm::Target host_target) { return transform::Sequential(pass_seqs); } -IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) { +IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets, + const Target& target_host) { Array<Pass> pass_seqs; Array<runtime::String> entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); @@ -988,7 +989,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe // external codegen. 
pass_seqs.push_back(transform::Inline()); - pass_seqs.push_back(MemoryOpt(this->target_host_)); + pass_seqs.push_back(MemoryOpt(target_host)); transform::Sequential seq(pass_seqs); transform::PassContext pass_ctx = PassContext::Current(); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index d1e1f7e83cc2..b4b86d3d6d8e 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -112,7 +112,8 @@ class VMCompiler : public runtime::ModuleNode { void Codegen(); protected: - IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets); + IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets, + const Target& target_host); void PopulateGlobalMap(); diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index e96d36258c67..a69f928bae58 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -593,10 +593,20 @@ def test_add_op_broadcast(): mod["main"] = func check_result([x_data, y_data], x_data + y_data, mod=mod) +def test_vm_optimize_dynamic(): + dtype = 'float32' + x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype=dtype) + y = relay.var('y', shape=(relay.Any(), relay.Any()), dtype=dtype) + mod = tvm.IRModule() + mod['main'] = relay.Function([x, y], relay.add(x, y)) + comp = relay.vm.VMCompiler() + opt_mod, _ = comp.optimize(mod, target="llvm") + assert "shape_func" in opt_mod.astext(False) + def test_vm_optimize(): mod, params = testing.synthetic.get_workload() comp = relay.vm.VMCompiler() - opt_mod, _ = comp.optimize(mod, "llvm", params) + opt_mod, _ = comp.optimize(mod, target="llvm", params=params) def test_loop_free_var(): x = relay.var('x', shape=(), dtype='int32')