Skip to content

Commit

Permalink
move storage type vector from nnvm to mxnet (apache#7054)
Browse files Browse the repository at this point in the history
* move storage type vector from nnvm to mxnet

* update nnvm

* update nnvm
  • Loading branch information
eric-haibin-lin committed Jul 25, 2017
1 parent 4c3e8b5 commit cf5747b
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 20 deletions.
30 changes: 30 additions & 0 deletions include/mxnet/graph_attr_types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*!
* Copyright (c) 2016 by Contributors
* \file graph_attr_types.h
* \brief Data structures that can appear in graph attributes.
*/
#ifndef MXNET_GRAPH_ATTR_TYPES_H_
#define MXNET_GRAPH_ATTR_TYPES_H_

#include <vector>

namespace mxnet {

/*!
* \brief The result holder of storage type of each NodeEntry in the graph.
* \note Stored under graph.attrs["storage_type"], provided by Pass "InferStorageType"
*
* \code
* Graph g = ApplyPass(src_graph, "InferStorageType");
* const StorageVector& stypes = g.GetAttr<StorageTypeVector>("storage_type");
* // get shape by entry id
* int entry_type = stypes[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
* \sa FInferStorageType
*/
using StorageTypeVector = std::vector<int>;

} // namespace mxnet

#endif // MXNET_GRAPH_ATTR_TYPES_H_
3 changes: 2 additions & 1 deletion src/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <mxnet/engine.h>
#include <mxnet/ndarray.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/graph_attr_types.h>
#include <nnvm/graph_attr_types.h>

#include <memory>
Expand Down Expand Up @@ -98,7 +99,7 @@ inline void CastNonDefaultStorage(const std::vector<NDArray>& dst,
}

// Check if any storage type is not default storage
inline bool ContainsNonDefaultStorage(const nnvm::StorageTypeVector& vstorage) {
inline bool ContainsNonDefaultStorage(const StorageTypeVector& vstorage) {
for (auto& i : vstorage) {
if (i != kUndefinedStorage && i != kDefaultStorage) return true;
}
Expand Down
1 change: 1 addition & 0 deletions src/executor/attach_op_execs_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <mxnet/base.h>
#include <mxnet/operator.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/graph_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include "../common/utils.h"
#include "./exec_pass.h"
Expand Down
3 changes: 2 additions & 1 deletion src/executor/exec_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <mxnet/operator.h>
#include <mxnet/graph_attr_types.h>
#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
#include <vector>
Expand Down Expand Up @@ -149,7 +150,7 @@ Graph InferType(Graph graph,
* The index of StorageTypeVector is given by graph.indexed_graph().entry_id.
*/
Graph InferStorageType(Graph graph,
nnvm::StorageTypeVector storage_type_inputs,
StorageTypeVector storage_type_inputs,
const std::string& storage_type_attr_key = "");

} // namespace exec
Expand Down
27 changes: 13 additions & 14 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ void HandleInferTypeError(const size_t num_forward_inputs,

void HandleInferStorageTypeError(const size_t num_forward_inputs,
const nnvm::IndexedGraph& idx,
const nnvm::StorageTypeVector& inferred_stypes) {
const StorageTypeVector& inferred_stypes) {
int cnt = 10;
std::ostringstream oss;
for (size_t i = 0; i < num_forward_inputs; ++i) {
Expand Down Expand Up @@ -505,7 +505,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
data_entry_.resize(idx.num_node_entries());
nnvm::ShapeVector arg_shapes;
nnvm::DTypeVector arg_dtypes;
nnvm::StorageTypeVector arg_stypes;
StorageTypeVector arg_stypes;
for (size_t i = 0; i < num_forward_inputs_; ++i) {
const uint32_t nid = idx.input_nodes().at(i);
const std::string& arg_name = idx[nid].source->attrs.name;
Expand Down Expand Up @@ -555,7 +555,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
g = InferStorageType(std::move(g), arg_stypes, "__storage_type__");
if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
g.GetAttr<nnvm::StorageTypeVector>("storage_type"));
g.GetAttr<StorageTypeVector>("storage_type"));
}

// Initialize the rest attributes of the graph.
Expand All @@ -573,7 +573,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
const nnvm::ShapeVector& inferred_shapes,
const nnvm::DTypeVector& inferred_dtypes,
const nnvm::StorageTypeVector& inferred_stypes,
const StorageTypeVector& inferred_stypes,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
Expand Down Expand Up @@ -679,7 +679,7 @@ NDArray ReshapeOrCreate(const std::string& name,
void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
const nnvm::ShapeVector& inferred_shapes,
const nnvm::DTypeVector& inferred_dtypes,
const nnvm::StorageTypeVector& inferred_stypes,
const StorageTypeVector& inferred_stypes,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
Expand Down Expand Up @@ -802,13 +802,13 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
const auto& idx = g.indexed_graph();
// dispatch based on stype per operator
const auto& vstorage_type = g.GetAttr<nnvm::StorageTypeVector>("storage_type");
nnvm::StorageTypeVector dispatch_stypes(idx.num_nodes(), kUndefinedStorage);
const auto& vstorage_type = g.GetAttr<StorageTypeVector>("storage_type");
StorageTypeVector dispatch_stypes(idx.num_nodes(), kUndefinedStorage);
for (size_t nid = 0; nid < idx.num_nodes(); nid++) {
const auto& inode = idx[nid];
auto num_outputs = inode.source->num_outputs();
auto num_inputs = inode.inputs.size();
nnvm::StorageTypeVector vs(num_inputs + num_outputs, kUndefinedStorage);
StorageTypeVector vs(num_inputs + num_outputs, kUndefinedStorage);
for (size_t i = 0; i < num_inputs; i++) {
auto e = inode.inputs[i];
vs[i] = vstorage_type[idx.entry_id(e)];
Expand Down Expand Up @@ -919,7 +919,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
const nnvm::IndexedGraph& idx = g.indexed_graph();
nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape());
nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1);
nnvm::StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage);
StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage);
for (size_t i = 0; i < num_forward_inputs_; ++i) {
const uint32_t nid = idx.input_nodes().at(i);
const std::string& name = idx[nid].source->attrs.name;
Expand Down Expand Up @@ -951,21 +951,21 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
g = InferStorageType(std::move(g), arg_stypes, "__storage_type__");
if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
g.GetAttr<nnvm::StorageTypeVector>("storage_type"));
g.GetAttr<StorageTypeVector>("storage_type"));
}

// Create in_args, arg_grads, and aux_states using
// the inferred shapes and dtypes.
if (nullptr == shared_buffer) { // regular simple bind
InitArguments(idx, g.GetAttr<nnvm::ShapeVector>("shape"),
g.GetAttr<nnvm::DTypeVector>("dtype"),
g.GetAttr<nnvm::StorageTypeVector>("storage_type"),
g.GetAttr<StorageTypeVector>("storage_type"),
in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec);
} else { // simple bind using shared data arrays and shared_exec
InitArguments(idx, g.GetAttr<nnvm::ShapeVector>("shape"),
g.GetAttr<nnvm::DTypeVector>("dtype"),
g.GetAttr<nnvm::StorageTypeVector>("storage_type"),
g.GetAttr<StorageTypeVector>("storage_type"),
in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
grad_req_types, shared_arg_names, shared_exec,
shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec);
Expand Down Expand Up @@ -1018,7 +1018,6 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
// initialize the memory of each entries
void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
using nnvm::DTypeVector;
using nnvm::StorageTypeVector;
using nnvm::ShapeVector;
using nnvm::StorageVector;
// get the graph
Expand Down Expand Up @@ -1169,7 +1168,7 @@ void GraphExecutor::InitCachedOps() {
const auto& vctx = graph_.GetAttr<ContextVector>("context");
const auto& addto_entry = graph_.GetAttr<std::vector<int> >("addto_entry");
const auto& skip_plus_node = graph_.GetAttr<std::vector<int> >("skip_plus_node");
const auto& vstorage_type = graph_.GetAttr<nnvm::StorageTypeVector>("storage_type");
const auto& vstorage_type = graph_.GetAttr<StorageTypeVector>("storage_type");

op_nodes_.resize(idx.num_nodes());
// setup the array and requirements.
Expand Down
4 changes: 2 additions & 2 deletions src/executor/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class GraphExecutor : public Executor {
void InitArguments(const nnvm::IndexedGraph& idx,
const nnvm::ShapeVector& inferred_shapes,
const nnvm::DTypeVector& inferred_dtypes,
const nnvm::StorageTypeVector& inferred_stypes,
const StorageTypeVector& inferred_stypes,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
Expand All @@ -139,7 +139,7 @@ class GraphExecutor : public Executor {
void InitArguments(const nnvm::IndexedGraph& idx,
const nnvm::ShapeVector& inferred_shapes,
const nnvm::DTypeVector& inferred_dtypes,
const nnvm::StorageTypeVector& inferred_stypes,
const StorageTypeVector& inferred_stypes,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
Expand Down
3 changes: 2 additions & 1 deletion src/executor/infer_graph_attr_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/

#include <mxnet/op_attr_types.h>
#include <mxnet/graph_attr_types.h>
#include "./exec_pass.h"

namespace mxnet {
Expand Down Expand Up @@ -314,7 +315,7 @@ nnvm::Graph InferType(nnvm::Graph graph,
}

nnvm::Graph InferStorageType(nnvm::Graph graph,
nnvm::StorageTypeVector storage_type_inputs,
StorageTypeVector storage_type_inputs,
const std::string& storage_type_attr_key) {
using dmlc::any;
if (storage_type_inputs.size() != 0) {
Expand Down

0 comments on commit cf5747b

Please sign in to comment.