-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Relay] Port param dict save/load from NNVM (#2620)
- Loading branch information
Showing
7 changed files
with
279 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# pylint: disable=invalid-name | ||
"""Helper utility to save parameter dicts.""" | ||
import tvm | ||
|
||
_save_param_dict = tvm.get_global_func("tvm.relay._save_param_dict") | ||
_load_param_dict = tvm.get_global_func("tvm.relay._load_param_dict") | ||
|
||
def save_param_dict(params): | ||
"""Save parameter dictionary to binary bytes. | ||
The result binary bytes can be loaded by the | ||
GraphModule with API "load_params". | ||
Parameters | ||
---------- | ||
params : dict of str to NDArray | ||
The parameter dictionary. | ||
Returns | ||
------- | ||
param_bytes: bytearray | ||
Serialized parameters. | ||
Examples | ||
-------- | ||
.. code-block:: python | ||
# compile and save the modules to file. | ||
graph, lib, params = tvm.relay.build(func, target=target, params=params) | ||
module = graph_runtime.create(graph, lib, tvm.gpu(0)) | ||
# save the parameters as byte array | ||
param_bytes = tvm.relay.save_param_dict(params) | ||
# We can serialize the param_bytes and load it back later. | ||
# Pass in byte array to module to directly set parameters | ||
module.load_params(param_bytes) | ||
""" | ||
args = [] | ||
for k, v in params.items(): | ||
args.append(k) | ||
args.append(tvm.nd.array(v)) | ||
return _save_param_dict(*args) | ||
|
||
|
||
def load_param_dict(param_bytes): | ||
"""Load parameter dictionary to binary bytes. | ||
Parameters | ||
---------- | ||
param_bytes: bytearray | ||
Serialized parameters. | ||
Returns | ||
------- | ||
params : dict of str to NDArray | ||
The parameter dictionary. | ||
""" | ||
if isinstance(param_bytes, (bytes, str)): | ||
param_bytes = bytearray(param_bytes) | ||
load_arr = _load_param_dict(param_bytes) | ||
return {v.name : v.array for v in load_arr} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
/*! | ||
* Copyright (c) 2019 by Contributors | ||
* \file param_dict.cc | ||
* \brief Implementation and registration of parameter dictionary | ||
* serializing/deserializing functions. | ||
*/ | ||
#include "param_dict.h" | ||
|
||
#include <dmlc/memory_io.h> | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
using namespace runtime; | ||
|
||
TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict") | ||
.set_body([](TVMArgs args, TVMRetValue *rv) { | ||
CHECK_EQ(args.size() % 2, 0u); | ||
// `args` is in the form "key, value, key, value, ..." | ||
size_t num_params = args.size() / 2; | ||
std::vector<std::string> names; | ||
names.reserve(num_params); | ||
std::vector<DLTensor*> arrays; | ||
arrays.reserve(num_params); | ||
for (size_t i = 0; i < num_params * 2; i += 2) { | ||
names.emplace_back(args[i].operator std::string()); | ||
arrays.emplace_back(args[i + 1].operator DLTensor*()); | ||
} | ||
std::string bytes; | ||
dmlc::MemoryStringStream strm(&bytes); | ||
dmlc::Stream* fo = &strm; | ||
uint64_t header = kTVMNDArrayListMagic, reserved = 0; | ||
fo->Write(header); | ||
fo->Write(reserved); | ||
fo->Write(names); | ||
{ | ||
uint64_t sz = static_cast<uint64_t>(arrays.size()); | ||
fo->Write(sz); | ||
for (size_t i = 0; i < sz; ++i) { | ||
tvm::runtime::SaveDLTensor(fo, arrays[i]); | ||
} | ||
} | ||
TVMByteArray arr; | ||
arr.data = bytes.c_str(); | ||
arr.size = bytes.length(); | ||
*rv = arr; | ||
}); | ||
|
||
TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict") | ||
.set_body([](TVMArgs args, TVMRetValue *rv) { | ||
std::string bytes = args[0]; | ||
std::vector<std::string> names; | ||
dmlc::MemoryStringStream memstrm(&bytes); | ||
dmlc::Stream* strm = &memstrm; | ||
uint64_t header, reserved; | ||
CHECK(strm->Read(&header)) | ||
<< "Invalid parameters file format"; | ||
CHECK(header == kTVMNDArrayListMagic) | ||
<< "Invalid parameters file format"; | ||
CHECK(strm->Read(&reserved)) | ||
<< "Invalid parameters file format"; | ||
CHECK(strm->Read(&names)) | ||
<< "Invalid parameters file format"; | ||
uint64_t sz; | ||
strm->Read(&sz, sizeof(sz)); | ||
size_t size = static_cast<size_t>(sz); | ||
CHECK(size == names.size()) | ||
<< "Invalid parameters file format"; | ||
tvm::Array<NamedNDArray> ret; | ||
for (size_t i = 0; i < size; ++i) { | ||
tvm::runtime::NDArray temp; | ||
temp.Load(strm); | ||
auto n = tvm::make_node<NamedNDArrayNode>(); | ||
n->name = std::move(names[i]); | ||
n->array = temp; | ||
ret.push_back(NamedNDArray(n)); | ||
} | ||
*rv = ret; | ||
}); | ||
|
||
TVM_REGISTER_NODE_TYPE(NamedNDArrayNode); | ||
|
||
} // namespace relay | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
/*! | ||
* Copyright (c) 2019 by Contributors | ||
* \file param_dict.h | ||
* \brief Definitions for serializing and deserializing parameter dictionaries. | ||
*/ | ||
#ifndef TVM_RELAY_BACKEND_PARAM_DICT_H_ | ||
#define TVM_RELAY_BACKEND_PARAM_DICT_H_ | ||
|
||
#include <tvm/node/node.h> | ||
#include <tvm/packed_func_ext.h> | ||
#include <tvm/runtime/ndarray.h> | ||
#include <tvm/runtime/packed_func.h> | ||
|
||
#include <string> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
/*! \brief Magic number for NDArray list file */ | ||
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; | ||
|
||
/*! | ||
* \brief Wrapper node for naming `NDArray`s. | ||
*/ | ||
struct NamedNDArrayNode : public ::tvm::Node { | ||
std::string name; | ||
tvm::runtime::NDArray array; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { | ||
v->Visit("name", &name); | ||
v->Visit("array", &array); | ||
} | ||
|
||
static constexpr const char* _type_key = "NamedNDArray"; | ||
TVM_DECLARE_NODE_TYPE_INFO(NamedNDArrayNode, Node); | ||
}; | ||
|
||
TVM_DEFINE_NODE_REF(NamedNDArray, NamedNDArrayNode); | ||
|
||
} // namespace relay | ||
} // namespace tvm | ||
|
||
#endif // TVM_RELAY_BACKEND_PARAM_DICT_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import os | ||
import numpy as np | ||
import tvm | ||
import json | ||
import base64 | ||
from tvm._ffi.base import py_str | ||
from tvm.relay.op import add | ||
from tvm import relay | ||
from tvm import rpc | ||
from tvm.contrib import util, graph_runtime | ||
|
||
|
||
def test_save_load(): | ||
x = np.ones((10, 2)).astype("float32") | ||
y = np.ones((1, 2, 3)).astype("float32") | ||
params = {"x": x, "y": y} | ||
param_bytes = relay.save_param_dict(params) | ||
assert isinstance(param_bytes, bytearray) | ||
param2 = relay.load_param_dict(param_bytes) | ||
assert len(param2) == 2 | ||
np.testing.assert_equal(param2["x"].asnumpy(), x) | ||
np.testing.assert_equal(param2["y"].asnumpy(), y) | ||
|
||
|
||
def test_ndarray_reflection(): | ||
# Make two `NDArrayWrapper`s that point to the same underlying array. | ||
np_array = np.random.uniform(size=(10, 2)).astype("float32") | ||
tvm_array = tvm.nd.array(np_array) | ||
param_dict = {'x': tvm_array, 'y': tvm_array} | ||
assert param_dict['x'].same_as(param_dict['y']) | ||
# Serialize then deserialize `param_dict`. | ||
deser_param_dict = relay.load_param_dict(relay.save_param_dict(param_dict)) | ||
# Make sure the data matches the original data and `x` and `y` contain the same data. | ||
np.testing.assert_equal(deser_param_dict['x'].asnumpy(), tvm_array.asnumpy()) | ||
# Make sure `x` and `y` contain the same data. | ||
np.testing.assert_equal(deser_param_dict['x'].asnumpy(), deser_param_dict['y'].asnumpy()) | ||
|
||
|
||
def test_bigendian_rpc_param(): | ||
"""Test big endian rpc when there is a PowerPC RPC server available""" | ||
host = os.environ.get("TVM_POWERPC_TEST_HOST", None) | ||
port = os.environ.get("TVM_POWERPC_TEST_PORT", 9090) | ||
if host is None: | ||
return | ||
|
||
def verify_graph_runtime(remote, target, shape, dtype): | ||
x = relay.var('x') | ||
y = relay.const(1) | ||
z = relay.add(x, y) | ||
func = relay.Function([x], z) | ||
|
||
x_in = np.ones(shape).astype(dtype) | ||
params = {'x': x_in} | ||
graph, lib, params = relay.build(func, target=target, params=params) | ||
|
||
temp = util.tempdir() | ||
path_dso = temp.relpath("dev_lib.o") | ||
lib.save(path_dso) | ||
remote.upload(path_dso) | ||
lib = remote.load_module("dev_lib.o") | ||
ctx = remote.cpu(0) | ||
mod = graph_runtime.create(graph, lib, ctx) | ||
mod.load_params(relay.save_param_dict(params)) | ||
mod.run() | ||
out = mod.get_output(0, tvm.nd.empty(shape, dtype=dtype, ctx=ctx)) | ||
tvm.testing.assert_allclose(x_in + 1, out.asnumpy()) | ||
|
||
print("Test RPC connection to PowerPC...") | ||
remote = rpc.connect(host, port) | ||
target = "llvm -mtriple=powerpc-linux-gnu" | ||
for dtype in ["float32", "float64", "int32", "int8"]: | ||
verify_graph_runtime(remote, target, (10,), dtype) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_save_load() | ||
test_ndarray_reflection() | ||
test_bigendian_rpc_param() |