Support Module based interface runtime

FrozenGene committed Jun 9, 2020
1 parent 2e1ef8e commit 4ca9e56
Showing 10 changed files with 1,062 additions and 62 deletions.
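The header changes below make GraphRuntime inherit from GraphRuntimeFactory and extend Init() to take the parameter dictionary directly, alongside new SetGraphJson()/GetGraphJson() and SetParams()/GetParams() accessors. As orientation, here is a minimal sketch of how the extended Init() could be driven from C++; the helper name CreateRuntime and the surrounding setup are illustrative assumptions, not part of this commit.

#include <string>
#include <unordered_map>
#include <vector>

#include <tvm/runtime/memory.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>

#include "graph_runtime.h"  // the header shown in this diff

// Hypothetical helper: wrap a compiled library, graph JSON, contexts, and
// params into a ready-to-run GraphRuntime module via the extended Init().
tvm::runtime::Module CreateRuntime(
    const std::string& graph_json, tvm::runtime::Module lib,
    const std::vector<TVMContext>& ctxs,
    const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
  auto exec = tvm::runtime::make_object<tvm::runtime::GraphRuntime>();
  // Init() now accepts params as a fourth argument; it forwards them to
  // SetParams(), which binds each named array to the matching graph input.
  exec->Init(graph_json, lib, ctxs, params);
  return tvm::runtime::Module(exec);
}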
@@ -26,25 +26,27 @@
#define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_

#include <dlpack/dlpack.h>
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

#include <tvm/runtime/graph_runtime_factory.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <string>
#include <numeric>

namespace tvm {
namespace runtime {

/*! \brief macro to do C API call */
#define TVM_CCALL(func) \
{ \
int ret = (func); \
CHECK_EQ(ret, 0) << TVMGetLastError(); \
#define TVM_CCALL(func) \
{ \
int ret = (func); \
CHECK_EQ(ret, 0) \
<< TVMGetLastError(); \
}

/*! \brief Magic number for NDArray list file */
@@ -64,7 +66,7 @@ struct TVMOpParam {
* This runtime can be accessed in various languages via the
* TVM runtime PackedFunc API.
*/
class TVM_DLL GraphRuntime : public ModuleNode {
class TVM_DLL GraphRuntime : public GraphRuntimeFactory {
struct OpArgs {
std::vector<DLTensor> args;
std::vector<TVMValue> arg_values;
@@ -79,12 +81,15 @@ class TVM_DLL GraphRuntime : public ModuleNode {
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
virtual PackedFunc GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self);

/*!
* \return The type key of the executor.
*/
const char* type_key() const final { return "GraphRuntime"; }
const char* type_key() const final {
return "GraphRuntime";
}
void Run();

/*!
@@ -94,10 +99,13 @@ class TVM_DLL GraphRuntime : public ModuleNode {
* processor.
* \param ctxs The context of the host and devices where graph nodes will be
* executed on.
* \param params The params of the graph.
*/

void Init(const std::string& graph_json, tvm::runtime::Module module,
const std::vector<TVMContext>& ctxs);
void Init(const std::string& graph_json,
tvm::runtime::Module module,
const std::vector<TVMContext>& ctxs,
const std::unordered_map<std::string, tvm::runtime::NDArray>& params={});

/*!
* \brief Get the input index given the name of input.
@@ -167,9 +175,66 @@ class TVM_DLL GraphRuntime : public ModuleNode {
* \brief Get total number of nodes.
* \return Total number of nodes.
*/
uint32_t GetNumOfNodes() const { return static_cast<uint32_t>(nodes_.size()); }
uint32_t GetNumOfNodes() const {
return static_cast<uint32_t>(nodes_.size());
}

std::string GetNodeName(uint32_t nid) const {
return nodes_[nid].name;
}

/*!
* \brief Set graph json value.
* \param graph_json The graph json value we want to set.
*/
void SetGraphJson(const std::string& graph_json) {
graph_json_ = graph_json;
}

std::string GetNodeName(uint32_t nid) const { return nodes_[nid].name; }
/*!
* \brief Get the graph json.
* \return The graph json.
*/
std::string GetGraphJson() const {
return graph_json_;
}

/*!
* \brief Set the graph params.
* \param params The graph params value we want to set.
*/
void SetParams(const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
params_ = params;

// upload big arrays first to avoid memory issues in RPC mode
std::vector<std::string> keys;
for (const auto& p : params_) {
keys.emplace_back(p.first);
}
std::sort(std::begin(keys), std::end(keys),
[this](const std::string& lhs, const std::string& rhs) -> bool {
auto lhs_shape = params_[lhs].Shape();
auto rhs_shape = params_[rhs].Shape();
auto lhs_prod = std::accumulate(std::begin(lhs_shape), std::end(lhs_shape), 1, std::multiplies<int64_t>());
auto rhs_prod = std::accumulate(std::begin(rhs_shape), std::end(rhs_shape), 1, std::multiplies<int64_t>());
return lhs_prod > rhs_prod;
});

for (const auto& key : keys) {
int in_idx = this->GetInputIndex(key);
if (in_idx >= 0) {
this->SetInput(in_idx, const_cast<DLTensor*>(params_[key].operator->()));
}
}
}

/*!
* \brief Get the graph params.
* \return The graph params.
*/
std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() const {
return params_;
}

protected:
// Memory pool entry.
@@ -184,7 +249,7 @@ class TVM_DLL GraphRuntime : public ModuleNode {
uint32_t index;
uint32_t version;
// JSON Loader
void Load(dmlc::JSONReader* reader) {
void Load(dmlc::JSONReader *reader) {
reader->BeginArray();
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&node_id);
@@ -211,7 +276,7 @@ class TVM_DLL GraphRuntime : public ModuleNode {
// control deps
std::vector<uint32_t> control_deps;
// JSON Loader
void LoadAttrs(dmlc::JSONReader* reader, TVMOpParam* param) {
void LoadAttrs(dmlc::JSONReader *reader, TVMOpParam* param) {
int bitmask = 0;
std::string key, value;
reader->BeginObject();
@@ -231,10 +296,10 @@ class TVM_DLL GraphRuntime : public ModuleNode {
bitmask |= 8;
}
}
CHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "invalid format";
CHECK_EQ(bitmask, 1|2|4|8) << "invalid format";
}
// JSON Loader
void Load(dmlc::JSONReader* reader) {
void Load(dmlc::JSONReader *reader) {
reader->BeginObject();
int bitmask = 0;
std::string key;
@@ -256,17 +321,17 @@ class TVM_DLL GraphRuntime : public ModuleNode {
LOG(FATAL) << "do not support key " << key;
}
}
CHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format";
CHECK_EQ(bitmask, 1|2|4) << "invalid format";
}
};
struct GraphAttr {
size_t storage_num_not_alloctaed{0};
std::vector<int> storage_id;
std::vector<int> device_index;
std::vector<std::string> dltype;
std::vector<std::vector<int64_t>> shape;
std::vector<std::vector<int64_t> > shape;
// The graph attribute fields.
void Load(dmlc::JSONReader* reader) {
void Load(dmlc::JSONReader *reader) {
reader->BeginObject();
int bitmask = 0;
std::string key, type;
@@ -324,37 +389,37 @@ class TVM_DLL GraphRuntime : public ModuleNode {
CHECK(!reader->NextArrayItem());
}
}
CHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format";
CHECK_EQ(bitmask, 1|2|4) << "invalid format";
}
};
// The graph attribute fields.
void Load(dmlc::JSONReader* reader) {
reader->BeginObject();
int bitmask = 0;
std::string key;
while (reader->NextObjectItem(&key)) {
if (key == "nodes") {
reader->Read(&nodes_);
bitmask |= 1;
} else if (key == "arg_nodes") {
reader->Read(&input_nodes_);
bitmask |= 2;
} else if (key == "node_row_ptr") {
reader->Read(&node_row_ptr_);
bitmask |= 4;
} else if (key == "heads") {
reader->Read(&outputs_);
bitmask |= 8;
} else if (key == "attrs") {
reader->Read(&attrs_);
bitmask |= 16;
} else if (key == "metadata") {
break;
} else {
LOG(FATAL) << "key " << key << " is not supported";
void Load(dmlc::JSONReader *reader) {
reader->BeginObject();
int bitmask = 0;
std::string key;
while (reader->NextObjectItem(&key)) {
if (key == "nodes") {
reader->Read(&nodes_);
bitmask |= 1;
} else if (key == "arg_nodes") {
reader->Read(&input_nodes_);
bitmask |= 2;
} else if (key == "node_row_ptr") {
reader->Read(&node_row_ptr_);
bitmask |= 4;
} else if (key == "heads") {
reader->Read(&outputs_);
bitmask |= 8;
} else if (key == "attrs") {
reader->Read(&attrs_);
bitmask |= 16;
} else if (key == "metadata") {
break;
} else {
LOG(FATAL) << "key " << key << " is not supported";
}
}
}
CHECK_EQ(bitmask, 1 | 2 | 4 | 8 | 16) << "invalid format";
CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format";
}
/*! \brief Setup the temporal storage */
void SetupStorage();
@@ -367,14 +432,21 @@ class TVM_DLL GraphRuntime : public ModuleNode {
* \param num_inputs Number of inputs.
* \return The created executor.
*/
std::pair<std::function<void()>, std::shared_ptr<OpArgs>> CreateTVMOp(
const TVMOpParam& attrs, const std::vector<DLTensor>& args, size_t num_inputs);
std::pair<std::function<void()>, std::shared_ptr<OpArgs> > CreateTVMOp(
const TVMOpParam& attrs, const std::vector<DLTensor>& args,
size_t num_inputs);
// Get node entry index.
uint32_t entry_id(uint32_t nid, uint32_t index) const { return node_row_ptr_[nid] + index; }
uint32_t entry_id(uint32_t nid, uint32_t index) const {
return node_row_ptr_[nid] + index;
}
// Get node entry index.
uint32_t entry_id(const NodeEntry& e) const { return entry_id(e.node_id, e.index); }
uint32_t entry_id(const NodeEntry& e) const {
return entry_id(e.node_id, e.index);
}
// Number of node entries.
uint32_t num_node_entries() const { return node_row_ptr_.back(); }
uint32_t num_node_entries() const {
return node_row_ptr_.back();
}
/*! \brief The graph nodes. */
std::vector<Node> nodes_;
/*! \brief The argument nodes. */
@@ -389,6 +461,10 @@ class TVM_DLL GraphRuntime : public ModuleNode {
std::vector<NodeEntry> outputs_;
/*! \brief Additional graph attributes. */
GraphAttr attrs_;
/*! \brief The execution graph. */
std::string graph_json_;
/*! \brief The params. */
std::unordered_map<std::string, tvm::runtime::NDArray> params_;
/*! \brief The code module that contains both host and device code. */
tvm::runtime::Module module_;
/*! \brief Execution context of all devices including the host. */
@@ -400,7 +476,7 @@ class TVM_DLL GraphRuntime : public ModuleNode {
/*! \brief Data alignment of each node. */
std::vector<size_t> data_alignment_;
/*! \brief Operator on each node. */
std::vector<std::function<void()>> op_execs_;
std::vector<std::function<void()> > op_execs_;
};

std::vector<TVMContext> GetAllContext(const TVMArgs& args);
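One detail worth noting from the new SetParams() above: keys are sorted by total element count, largest first, so the biggest tensors are uploaded before memory pressure builds up on an RPC remote. Below is a self-contained sketch of that ordering rule; the parameter names and shapes are made up for illustration.

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Hypothetical parameter shapes; only the ordering logic mirrors SetParams().
  std::unordered_map<std::string, std::vector<int64_t>> shapes = {
      {"fc_bias", {1000}}, {"conv0_weight", {64, 3, 7, 7}}, {"fc_weight", {1000, 512}}};

  std::vector<std::string> keys;
  for (const auto& p : shapes) keys.emplace_back(p.first);

  auto num_elems = [&](const std::string& k) {
    const auto& s = shapes.at(k);
    return std::accumulate(s.begin(), s.end(), int64_t{1}, std::multiplies<int64_t>());
  };
  std::sort(keys.begin(), keys.end(), [&](const std::string& lhs, const std::string& rhs) {
    return num_elems(lhs) > num_elems(rhs);  // descending element count
  });

  // Prints: fc_weight, conv0_weight, fc_bias (largest parameter first).
  for (const auto& k : keys) std::cout << k << "\n";
  return 0;
}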