From e2ce404178e7cdaad898fe12b6141b60d3133e19 Mon Sep 17 00:00:00 2001 From: Logan Weber <36520469+weberlo@users.noreply.github.com> Date: Tue, 26 Feb 2019 21:15:39 -0800 Subject: [PATCH] [Relay] Port param dict save/load from NNVM (#2620) --- python/tvm/api.py | 2 +- python/tvm/relay/__init__.py | 5 ++ python/tvm/relay/param_dict.py | 60 ++++++++++++++++++ src/relay/backend/interpreter.cc | 5 ++ src/relay/backend/param_dict.cc | 87 +++++++++++++++++++++++++++ src/relay/backend/param_dict.h | 43 +++++++++++++ tests/python/relay/test_param_dict.py | 78 ++++++++++++++++++++++++ 7 files changed, 279 insertions(+), 1 deletion(-) create mode 100644 python/tvm/relay/param_dict.py create mode 100644 src/relay/backend/param_dict.cc create mode 100644 src/relay/backend/param_dict.h create mode 100644 tests/python/relay/test_param_dict.py diff --git a/python/tvm/api.py b/python/tvm/api.py index 10a97171e58f..514490ae83ea 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -136,7 +136,7 @@ def load_json(json_str): def save_json(node): - """Load tvm object as json string. + """Save tvm object as json string. Parameters ---------- diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index fe00877c0fb0..6d44d07f4bbf 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -13,6 +13,7 @@ from . import prelude from . import parser from . import debug +from . import param_dict # Root operators from .op import Op @@ -85,3 +86,7 @@ # Parser fromtext = parser.fromtext + +# Param Serialization +save_param_dict = param_dict.save_param_dict +load_param_dict = param_dict.load_param_dict diff --git a/python/tvm/relay/param_dict.py b/python/tvm/relay/param_dict.py new file mode 100644 index 000000000000..f7647beadeb2 --- /dev/null +++ b/python/tvm/relay/param_dict.py @@ -0,0 +1,60 @@ +# pylint: disable=invalid-name +"""Helper utility to save parameter dicts.""" +import tvm + +_save_param_dict = tvm.get_global_func("tvm.relay._save_param_dict") +_load_param_dict = tvm.get_global_func("tvm.relay._load_param_dict") + +def save_param_dict(params): + """Save parameter dictionary to binary bytes. + + The result binary bytes can be loaded by the + GraphModule with API "load_params". + + Parameters + ---------- + params : dict of str to NDArray + The parameter dictionary. + + Returns + ------- + param_bytes: bytearray + Serialized parameters. + + Examples + -------- + .. code-block:: python + + # compile and save the modules to file. + graph, lib, params = tvm.relay.build(func, target=target, params=params) + module = graph_runtime.create(graph, lib, tvm.gpu(0)) + # save the parameters as byte array + param_bytes = tvm.relay.save_param_dict(params) + # We can serialize the param_bytes and load it back later. + # Pass in byte array to module to directly set parameters + module.load_params(param_bytes) + """ + args = [] + for k, v in params.items(): + args.append(k) + args.append(tvm.nd.array(v)) + return _save_param_dict(*args) + + +def load_param_dict(param_bytes): + """Load parameter dictionary to binary bytes. + + Parameters + ---------- + param_bytes: bytearray + Serialized parameters. + + Returns + ------- + params : dict of str to NDArray + The parameter dictionary. + """ + if isinstance(param_bytes, (bytes, str)): + param_bytes = bytearray(param_bytes) + load_arr = _load_param_dict(param_bytes) + return {v.name : v.array for v in load_arr} diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 4ef893f463e9..3128d2a71159 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -578,5 +578,10 @@ TVM_REGISTER_API("relay.backend.CreateInterpreter") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = CreateInterpreter(args[0], args[1], args[2]); }); + +TVM_REGISTER_NODE_TYPE(ClosureNode); +TVM_REGISTER_NODE_TYPE(TupleValueNode); +TVM_REGISTER_NODE_TYPE(TensorValueNode); + } // namespace relay } // namespace tvm diff --git a/src/relay/backend/param_dict.cc b/src/relay/backend/param_dict.cc new file mode 100644 index 000000000000..87d3dd373e83 --- /dev/null +++ b/src/relay/backend/param_dict.cc @@ -0,0 +1,87 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file param_dict.cc + * \brief Implementation and registration of parameter dictionary + * serializing/deserializing functions. + */ +#include "param_dict.h" + +#include + +#include +#include + +namespace tvm { +namespace relay { + +using namespace runtime; + +TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict") +.set_body([](TVMArgs args, TVMRetValue *rv) { + CHECK_EQ(args.size() % 2, 0u); + // `args` is in the form "key, value, key, value, ..." + size_t num_params = args.size() / 2; + std::vector names; + names.reserve(num_params); + std::vector arrays; + arrays.reserve(num_params); + for (size_t i = 0; i < num_params * 2; i += 2) { + names.emplace_back(args[i].operator std::string()); + arrays.emplace_back(args[i + 1].operator DLTensor*()); + } + std::string bytes; + dmlc::MemoryStringStream strm(&bytes); + dmlc::Stream* fo = &strm; + uint64_t header = kTVMNDArrayListMagic, reserved = 0; + fo->Write(header); + fo->Write(reserved); + fo->Write(names); + { + uint64_t sz = static_cast(arrays.size()); + fo->Write(sz); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::SaveDLTensor(fo, arrays[i]); + } + } + TVMByteArray arr; + arr.data = bytes.c_str(); + arr.size = bytes.length(); + *rv = arr; + }); + +TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict") +.set_body([](TVMArgs args, TVMRetValue *rv) { + std::string bytes = args[0]; + std::vector names; + dmlc::MemoryStringStream memstrm(&bytes); + dmlc::Stream* strm = &memstrm; + uint64_t header, reserved; + CHECK(strm->Read(&header)) + << "Invalid parameters file format"; + CHECK(header == kTVMNDArrayListMagic) + << "Invalid parameters file format"; + CHECK(strm->Read(&reserved)) + << "Invalid parameters file format"; + CHECK(strm->Read(&names)) + << "Invalid parameters file format"; + uint64_t sz; + strm->Read(&sz, sizeof(sz)); + size_t size = static_cast(sz); + CHECK(size == names.size()) + << "Invalid parameters file format"; + tvm::Array ret; + for (size_t i = 0; i < size; ++i) { + tvm::runtime::NDArray temp; + temp.Load(strm); + auto n = tvm::make_node(); + n->name = std::move(names[i]); + n->array = temp; + ret.push_back(NamedNDArray(n)); + } + *rv = ret; + }); + +TVM_REGISTER_NODE_TYPE(NamedNDArrayNode); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/param_dict.h b/src/relay/backend/param_dict.h new file mode 100644 index 000000000000..0c32d2bf4742 --- /dev/null +++ b/src/relay/backend/param_dict.h @@ -0,0 +1,43 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file param_dict.h + * \brief Definitions for serializing and deserializing parameter dictionaries. + */ +#ifndef TVM_RELAY_BACKEND_PARAM_DICT_H_ +#define TVM_RELAY_BACKEND_PARAM_DICT_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relay { + +/*! \brief Magic number for NDArray list file */ +constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; + +/*! + * \brief Wrapper node for naming `NDArray`s. + */ +struct NamedNDArrayNode : public ::tvm::Node { + std::string name; + tvm::runtime::NDArray array; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("array", &array); + } + + static constexpr const char* _type_key = "NamedNDArray"; + TVM_DECLARE_NODE_TYPE_INFO(NamedNDArrayNode, Node); +}; + +TVM_DEFINE_NODE_REF(NamedNDArray, NamedNDArrayNode); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_PARAM_DICT_H_ diff --git a/tests/python/relay/test_param_dict.py b/tests/python/relay/test_param_dict.py new file mode 100644 index 000000000000..b398ea8ba2f5 --- /dev/null +++ b/tests/python/relay/test_param_dict.py @@ -0,0 +1,78 @@ +import os +import numpy as np +import tvm +import json +import base64 +from tvm._ffi.base import py_str +from tvm.relay.op import add +from tvm import relay +from tvm import rpc +from tvm.contrib import util, graph_runtime + + +def test_save_load(): + x = np.ones((10, 2)).astype("float32") + y = np.ones((1, 2, 3)).astype("float32") + params = {"x": x, "y": y} + param_bytes = relay.save_param_dict(params) + assert isinstance(param_bytes, bytearray) + param2 = relay.load_param_dict(param_bytes) + assert len(param2) == 2 + np.testing.assert_equal(param2["x"].asnumpy(), x) + np.testing.assert_equal(param2["y"].asnumpy(), y) + + +def test_ndarray_reflection(): + # Make two `NDArrayWrapper`s that point to the same underlying array. + np_array = np.random.uniform(size=(10, 2)).astype("float32") + tvm_array = tvm.nd.array(np_array) + param_dict = {'x': tvm_array, 'y': tvm_array} + assert param_dict['x'].same_as(param_dict['y']) + # Serialize then deserialize `param_dict`. + deser_param_dict = relay.load_param_dict(relay.save_param_dict(param_dict)) + # Make sure the data matches the original data and `x` and `y` contain the same data. + np.testing.assert_equal(deser_param_dict['x'].asnumpy(), tvm_array.asnumpy()) + # Make sure `x` and `y` contain the same data. + np.testing.assert_equal(deser_param_dict['x'].asnumpy(), deser_param_dict['y'].asnumpy()) + + +def test_bigendian_rpc_param(): + """Test big endian rpc when there is a PowerPC RPC server available""" + host = os.environ.get("TVM_POWERPC_TEST_HOST", None) + port = os.environ.get("TVM_POWERPC_TEST_PORT", 9090) + if host is None: + return + + def verify_graph_runtime(remote, target, shape, dtype): + x = relay.var('x') + y = relay.const(1) + z = relay.add(x, y) + func = relay.Function([x], z) + + x_in = np.ones(shape).astype(dtype) + params = {'x': x_in} + graph, lib, params = relay.build(func, target=target, params=params) + + temp = util.tempdir() + path_dso = temp.relpath("dev_lib.o") + lib.save(path_dso) + remote.upload(path_dso) + lib = remote.load_module("dev_lib.o") + ctx = remote.cpu(0) + mod = graph_runtime.create(graph, lib, ctx) + mod.load_params(relay.save_param_dict(params)) + mod.run() + out = mod.get_output(0, tvm.nd.empty(shape, dtype=dtype, ctx=ctx)) + tvm.testing.assert_allclose(x_in + 1, out.asnumpy()) + + print("Test RPC connection to PowerPC...") + remote = rpc.connect(host, port) + target = "llvm -mtriple=powerpc-linux-gnu" + for dtype in ["float32", "float64", "int32", "int8"]: + verify_graph_runtime(remote, target, (10,), dtype) + + +if __name__ == "__main__": + test_save_load() + test_ndarray_reflection() + test_bigendian_rpc_param()