diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 31009008b23c..f1cdefc3ed3a 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -23,6 +23,7 @@ import numpy as np import tvm +import tvm.ndarray as _nd from tvm import autotvm, container from tvm.object import Object from tvm.relay import expr as _expr @@ -409,6 +410,8 @@ def __init__(self): self._codegen = self.mod["codegen"] self._get_exec = self.mod["get_executable"] self._set_params_func = self.mod["set_params"] + self._get_params_func = self.mod["get_params"] + self._optimize = self.mod["optimize"] def set_params(self, params): """Set constant parameters for the model. @@ -426,6 +429,14 @@ def set_params(self, params): inputs[name] = _expr.const(param) self._set_params_func(inputs) + def get_params(self): + """Return the updated weights.""" + params = self._get_params_func() + ret = {} + for key, value in params.items(): + ret[key] = value.data + return ret + def lower(self, mod, target=None, target_host=None): """Lower the module to VM bytecode. @@ -458,6 +469,33 @@ def codegen(self): """Generate the kernel library.""" self._codegen() + def optimize(self, mod, target=None, params=None): + """Helper method that optimizes a Relay module via VM. + + Parameters + ---------- + mod : relay.Module + + target : str, :any:`tvm.target.Target`, or dict of str (i.e. + device/context name) to str/tvm.target.Target, optional + + params : dict of str to NDArray + Input parameters to the graph that do not change + during inference time. Used for constant folding. + + Returns + ------- + mod : relay.Module + The optimized relay module. + + params : dict + The parameters of the final module. + """ + target = self._update_target(target) + if params: + self.set_params(params) + return self._optimize(mod, target), self.get_params() + def get_exec(self): """Get the VM executable. diff --git a/python/tvm/relay/scope_builder.py b/python/tvm/relay/scope_builder.py index 16044c127e98..43c653203c81 100644 --- a/python/tvm/relay/scope_builder.py +++ b/python/tvm/relay/scope_builder.py @@ -18,6 +18,7 @@ """The scope builder interface.""" from __future__ import absolute_import +from . import ty as _ty from . import expr as _expr from .._ffi import base as _base diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index cc5d6bc02a47..8d4f4addaca9 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -772,6 +772,19 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, this->SetParam(kv.first, kv.second->data); } }); + } else if (name == "get_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + Map ret; + for (const auto& kv : params_) { + ret.Set(kv.first, ConstantNode::make(kv.second)); + } + *rv = ret; + }); + } else if (name == "optimize") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.num_args, 2); + *rv = this->OptimizeModule(args[0], args[1]); + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index d4a7a1a25689..9ea939ce9c83 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -22,6 +22,7 @@ from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.testing.config import ctx_list from tvm.relay.prelude import Prelude +from tvm.relay import testing import pytest def check_result(args, expected_result, mod=None): @@ -570,6 +571,10 @@ def test_add_op_broadcast(): mod["main"] = func check_result([x_data, y_data], x_data + y_data, mod=mod) +def test_vm_optimize(): + mod, params = testing.resnet.get_workload(batch_size=1, num_layers=18) + comp = relay.backend.vm.VMCompiler() + opt_mod, _ = comp.optimize(mod, "llvm", params) if __name__ == "__main__": pytest.main([__file__])