[RUNTIME] Fix graph runtime for gpu (#491)
tqchen authored Sep 26, 2017
1 parent c468558 commit fd864c5
Showing 2 changed files with 38 additions and 19 deletions.
11 changes: 11 additions & 0 deletions python/tvm/contrib/graph_runtime.py
@@ -72,6 +72,7 @@ def __init__(self, module, ctx):
self._set_input = module["set_input"]
self._run = module["run"]
self._get_output = module["get_output"]
+        self._load_params = module["load_params"]
self.ctx = ctx

def set_input(self, key=None, value=None, **params):
@@ -120,6 +121,16 @@ def get_output(self, index, out):
self._get_output(index, out)
return out

+    def load_params(self, params_bytes):
+        """Load parameters from serialized byte array of parameter dict.
+
+        Parameters
+        ----------
+        params_bytes : bytearray
+            The serialized parameter dict.
+        """
+        self._load_params(bytearray(params_bytes))
+
def __getitem__(self, key):
"""Get internal module function
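The new load_params method lets a deployed module ingest a serialized parameter dict directly, with each tensor copied onto the runtime's context. A minimal usage sketch follows; the file names, the "data" input name, and the shapes are placeholders, and it assumes a graph that was compiled and exported ahead of time:

import numpy as np
import tvm
from tvm.contrib import graph_runtime

# Hypothetical artifacts from an earlier compile/export step.
loaded_lib = tvm.module.load("deploy.so")
loaded_graph = open("deploy.json").read()
loaded_params = bytearray(open("deploy.params", "rb").read())

ctx = tvm.gpu(0)
module = graph_runtime.create(loaded_graph, loaded_lib, ctx)
module.load_params(loaded_params)  # new in this commit

# Example input/output shapes; replace with the real ones for the graph.
x = np.random.uniform(size=(1, 3, 224, 224)).astype("float32")
module.set_input("data", tvm.nd.array(x, ctx))
module.run()
out = module.get_output(0, tvm.nd.empty((1, 1000), ctx=ctx))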
46 changes: 27 additions & 19 deletions src/runtime/graph/graph_runtime.cc
@@ -299,7 +299,7 @@ class GraphRuntime : public ModuleNode {
}
CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format";
}
-  bool LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor);
+  void LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor);
/*! \brief Setup the temporal storage */
void SetupStorage();
/*! \brief Setup the executors */
@@ -353,7 +353,7 @@ class GraphRuntime : public ModuleNode {
};


-bool GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
+void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
uint64_t header, reserved;
CHECK(strm->Read(&header, sizeof(header)))
<< "Invalid DLTensor file format";
@@ -362,30 +362,37 @@ bool GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";

-  CHECK(strm->Read(&tensor->ctx, sizeof(tensor->ctx)))
+  DLTensor tensor;
+  CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx)))
<< "Invalid DLTensor file format";
-  CHECK(strm->Read(&tensor->ndim, sizeof(tensor->ndim)))
+  CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim)))
<< "Invalid DLTensor file format";
-  CHECK(strm->Read(&tensor->dtype, sizeof(tensor->dtype)))
+  CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype)))
<< "Invalid DLTensor file format";

-  int ndim = tensor->ndim;
-  CHECK(strm->Read(tensor->shape, sizeof(int64_t) * ndim))
+  std::vector<int64_t> shape(tensor.ndim);
+  CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
<< "Invalid DLTensor file format";

-  int64_t size = 1;
-  int type_size = tensor->dtype.bits / 8;
-  for (int i = 0; i < ndim; ++i) {
-    size *= tensor->shape[i];
+  CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch";
+  CHECK(tensor.dtype.bits == dst->dtype.bits &&
+        tensor.dtype.code == dst->dtype.code &&
+        tensor.dtype.lanes == dst->dtype.lanes) << "param type mismatch";
+  for (int i = 0; i < tensor.ndim; ++i) {
+    CHECK_EQ(shape[i], dst->shape[i]) << "param shape mismatch";
}
+  size_t bits = dst->dtype.bits * dst->dtype.lanes;
+  size_t size = (bits + 7) / 8;
+  for (int i = 0; i < dst->ndim; ++i) {
+    size *= dst->shape[i];
+  }
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size)))
<< "Invalid DLTensor file format";
-  CHECK(data_byte_size == type_size * size)
+  CHECK(data_byte_size == size)
<< "Invalid DLTensor file format";
-  CHECK(strm->Read(tensor->data, type_size * size))
+  std::vector<uint8_t> bytes(data_byte_size + 1);
+  CHECK(strm->Read(&bytes[0], data_byte_size))
<< "Invalid DLTensor file format";
-  return true;
+  TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size));
}

void GraphRuntime::LoadParams(dmlc::Stream* strm) {
@@ -406,11 +413,11 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {

CHECK(size == names.size())
<< "Invalid parameters file format";

for (size_t i = 0; i < size; ++i) {
uint32_t in_idx = GetInputIndex(names[i]);
-    CHECK(LoadDLTensor(strm, &data_entry_[this->entry_id(input_nodes_[in_idx], 0)]))
-        << "Invalid parameters file format";
+    uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
+    CHECK_LT(eid, data_entry_.size());
+    LoadDLTensor(strm, &data_entry_[eid]);
}
}

@@ -461,6 +468,7 @@ void GraphRuntime::SetupStorage() {
// Assign the pooled entries.
for (size_t i = 0; i < data_entry_.size(); ++i) {
int storage_id = attrs_.storage_id[i];
+    CHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size());
data_entry_[i] = *storage_pool_[storage_id];
data_entry_[i].shape = const_cast<int64_t*>(attrs_.shape[i].data());
data_entry_[i].ndim = static_cast<int>(attrs_.shape[i].size());
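As the rewritten LoadDLTensor shows, each parameter record in the blob is a fixed sequence: a uint64 magic header and a reserved uint64, the DLContext, the ndim field, the DLDataType, ndim int64 shape values, an int64 data byte count, and the raw bytes. The loader now stages those bytes in a host-side buffer and hands them to TVMArrayCopyFromBytes, so the copy goes through the device API and works for GPU arrays as well as CPU ones, instead of streaming straight into tensor->data, which only works for host memory. A rough Python sketch of that per-tensor record, assuming little-endian data and the usual dlpack field widths (DLContext as two int32s, DLDataType as uint8 code, uint8 bits, uint16 lanes); this is an illustration, not the authoritative serializer:

import struct

def parse_tensor_record(buf, offset=0):
    """Illustrative parser for one tensor record inside a params blob."""
    header, reserved = struct.unpack_from("<QQ", buf, offset)   # magic + reserved
    offset += 16
    device_type, device_id = struct.unpack_from("<ii", buf, offset)  # DLContext
    offset += 8
    (ndim,) = struct.unpack_from("<i", buf, offset)
    offset += 4
    code, bits, lanes = struct.unpack_from("<BBH", buf, offset)  # DLDataType
    offset += 4
    shape = struct.unpack_from("<%dq" % ndim, buf, offset)
    offset += 8 * ndim
    (nbytes,) = struct.unpack_from("<q", buf, offset)            # data_byte_size
    offset += 8
    data = buf[offset:offset + nbytes]
    offset += nbytes
    return {"ctx": (device_type, device_id), "dtype": (code, bits, lanes),
            "shape": shape, "data": data}, offset

The data_byte_size == size check above enforces that the stored byte count matches the size implied by the destination tensor's dtype and shape, which is why the new code validates ndim, dtype, and shape against the destination before performing the copy.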
