diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 8b7954f0ed48..7e919586b0c0 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -72,6 +72,7 @@ def __init__(self, module, ctx): self._set_input = module["set_input"] self._run = module["run"] self._get_output = module["get_output"] + self._load_params = module["load_params"] self.ctx = ctx def set_input(self, key=None, value=None, **params): @@ -120,6 +121,16 @@ def get_output(self, index, out): self._get_output(index, out) return out + def load_params(self, params_bytes): + """Load parameters from serialized byte array of parameter dict. + + Parameters + ---------- + params_bytes : bytearray + The serialized parameter dict. + """ + self._load_params(bytearray(params_bytes)) + def __getitem__(self, key): """Get internal module function diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index f305367d4a2d..2cf6a1fb1330 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -299,7 +299,7 @@ class GraphRuntime : public ModuleNode { } CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format"; } - bool LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor); + void LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor); /*! \brief Setup the temporal storage */ void SetupStorage(); /*! \brief Setup the executors */ @@ -353,7 +353,7 @@ class GraphRuntime : public ModuleNode { }; -bool GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) { +void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) { uint64_t header, reserved; CHECK(strm->Read(&header, sizeof(header))) << "Invalid DLTensor file format"; @@ -362,30 +362,37 @@ bool GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) { CHECK(header == kTVMNDArrayMagic) << "Invalid DLTensor file format"; - CHECK(strm->Read(&tensor->ctx, sizeof(tensor->ctx))) + DLTensor tensor; + CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx))) << "Invalid DLTensor file format"; - CHECK(strm->Read(&tensor->ndim, sizeof(tensor->ndim))) + CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim))) << "Invalid DLTensor file format"; - CHECK(strm->Read(&tensor->dtype, sizeof(tensor->dtype))) + CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype))) << "Invalid DLTensor file format"; - - int ndim = tensor->ndim; - CHECK(strm->Read(tensor->shape, sizeof(int64_t) * ndim)) + std::vector shape(tensor.ndim); + CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim)) << "Invalid DLTensor file format"; - - int64_t size = 1; - int type_size = tensor->dtype.bits / 8; - for (int i = 0; i < ndim; ++i) { - size *= tensor->shape[i]; + CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch"; + CHECK(tensor.dtype.bits == dst->dtype.bits && + tensor.dtype.code == dst->dtype.code && + tensor.dtype.lanes == dst->dtype.lanes) << "param type mismatch"; + for (int i = 0; i < tensor.ndim; ++i) { + CHECK_EQ(shape[i], dst->shape[i]) << "param shape mismatch"; + } + size_t bits = dst->dtype.bits * dst->dtype.lanes; + size_t size = (bits + 7) / 8; + for (int i = 0; i < dst->ndim; ++i) { + size *= dst->shape[i]; } int64_t data_byte_size; CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size))) << "Invalid DLTensor file format"; - CHECK(data_byte_size == type_size * size) + CHECK(data_byte_size == size) << "Invalid DLTensor file format"; - CHECK(strm->Read(tensor->data, type_size * size)) + std::vector bytes(data_byte_size + 1); + CHECK(strm->Read(&bytes[0], data_byte_size)) << "Invalid DLTensor file format"; - return true; + TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size)); } void GraphRuntime::LoadParams(dmlc::Stream* strm) { @@ -406,11 +413,11 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { CHECK(size == names.size()) << "Invalid parameters file format"; - for (size_t i = 0; i < size; ++i) { uint32_t in_idx = GetInputIndex(names[i]); - CHECK(LoadDLTensor(strm, &data_entry_[this->entry_id(input_nodes_[in_idx], 0)])) - << "Invalid parameters file format"; + uint32_t eid = this->entry_id(input_nodes_[in_idx], 0); + CHECK_LT(eid, data_entry_.size()); + LoadDLTensor(strm, &data_entry_[eid]); } } @@ -461,6 +468,7 @@ void GraphRuntime::SetupStorage() { // Assign the pooled entries. for (size_t i = 0; i < data_entry_.size(); ++i) { int storage_id = attrs_.storage_id[i]; + CHECK_LT(static_cast(storage_id), storage_pool_.size()); data_entry_[i] = *storage_pool_[storage_id]; data_entry_[i].shape = const_cast(attrs_.shape[i].data()); data_entry_[i].ndim = static_cast(attrs_.shape[i].size());