From 93534ece40f2abe9ff47083e556eba83e9a2bc76 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 9 Jul 2016 12:03:29 -0700 Subject: [PATCH] [SYMBOLIC] Add symbolic API (#2) * [SYMBOLIC] Add symbolic API * Update Testcase to nnvm --- nnvm/include/nnvm/graph.h | 27 ++- nnvm/include/nnvm/node.h | 17 +- nnvm/include/nnvm/op.h | 85 +++++-- nnvm/include/nnvm/op_attr_types.h | 42 ++++ nnvm/include/nnvm/symbolic.h | 148 ++++++++++++ nnvm/src/core/graph_attr_types.cc | 2 +- nnvm/src/core/op.cc | 25 +- nnvm/src/core/symbolic.cc | 369 ++++++++++++++++++++++++++++++ nnvm/src/pass/saveload_json.cc | 2 +- nnvm/src/test_main.cc | 6 +- nnvm/tests/cpp/op_test.cc | 8 +- nnvm/tests/cpp/tuple_test.cc | 8 +- 12 files changed, 680 insertions(+), 59 deletions(-) create mode 100644 nnvm/include/nnvm/op_attr_types.h create mode 100644 nnvm/include/nnvm/symbolic.h create mode 100644 nnvm/src/core/symbolic.cc diff --git a/nnvm/include/nnvm/graph.h b/nnvm/include/nnvm/graph.h index 296f62bf839d..5269f8c33e19 100644 --- a/nnvm/include/nnvm/graph.h +++ b/nnvm/include/nnvm/graph.h @@ -14,11 +14,13 @@ #include #include "./base.h" #include "./node.h" +#include "./symbolic.h" namespace nnvm { /*! * \brief Symbolic computation graph. + * This is the intermediate representation for optimization pass. */ class Graph { public: @@ -30,16 +32,18 @@ class Graph { * and can be shared across multiple Instance of graph */ std::unordered_map > attrs; - /*! - * \brief perform a Post Order DFS visit to each node in the graph. - * This order is deterministic and is also topoligical sorted. - * \param fvisit a function of type std::function&)> - * \tparam FVisit The function type to perform the visit. - */ - template - inline void DFSVisit(FVisit fvisit) const; }; +/*! + * \brief perform a Post Order DFS visit to each node in the graph. + * This order is deterministic and is also topoligical sorted. + * \param heads The heads in the graph. + * \param fvisit a function of type std::function&)> + * \tparam FVisit The function type to perform the visit. + */ +template +inline void DFSVisit(const std::vector& heads, FVisit fvisit); + // inline function implementations template & heads, } template -inline void Graph::DFSVisit(FVisit fvisit) const { +inline void DFSVisit(const std::vector& heads, + FVisit fvisit) { typedef const std::shared_ptr* GNode; - std::vector head_nodes(outputs.size()); - std::transform(outputs.begin(), outputs.end(), head_nodes.begin(), + std::vector head_nodes(heads.size()); + std::transform(heads.begin(), heads.end(), head_nodes.begin(), [](const NodeEntry& e)->GNode { return &e.node; }); diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index 51e7a6049001..dd24c3bd6fa3 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -72,6 +72,8 @@ class Node { inline bool is_variable() const; /*! \return number of outputs from this node */ inline uint32_t num_outputs() const; + /*! \return number of inputs from this node */ + inline uint32_t num_inputs() const; /*! * \brief create a new empty shared_ptr of Node. * \return a created empty node. @@ -86,10 +88,19 @@ inline bool Node::is_variable() const { inline uint32_t Node::num_outputs() const { if (is_variable()) return 1; - if (this->op->num_outputs >= 0) { - return static_cast(this->op->num_outputs); + if (this->op->get_num_outputs == nullptr) { + return this->op->num_outputs; } else { - return this->op->get_num_outputs(*this); + return this->op->get_num_outputs(this->attrs); + } +} + +inline uint32_t Node::num_inputs() const { + if (is_variable()) return 1; + if (this->op->get_num_inputs == nullptr) { + return this->op->num_inputs; + } else { + return this->op->get_num_inputs(this->attrs); } } diff --git a/nnvm/include/nnvm/op.h b/nnvm/include/nnvm/op.h index a3467c5a6bcf..bcbd4d5cb939 100644 --- a/nnvm/include/nnvm/op.h +++ b/nnvm/include/nnvm/op.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include "./base.h" @@ -22,8 +23,8 @@ template class OpMap; class OpRegistryEntry; -/*! \brief constant to indicate variable length inout and output */ -static const int kVarg = -1; +/*! \brief constant to indicate it take any length of positional inputs */ +static const uint32_t kVarg = std::numeric_limits::max(); /*! * \brief Operator structure. @@ -79,23 +80,31 @@ class Op { /*! * \brief number of inputs to the operator, * -1 means it is variable length + * When get_num_inputs is presented, + * the number will be decided by get_num_inputs instead. + * \sa get_num_inputs */ - int num_inputs = 0; + uint32_t num_inputs = 1; /*! * \brief number of outputs of the operator - * -1 means it is variable length + * When get_num_outputs is presented. * The number of outputs will be decided by * get_num_outputs function * \sa get_num_outputs */ - int num_outputs = 1; + uint32_t num_outputs = 1; /*! * \brief get number of outputs given information about the node. - * This is only valid when num_outputs == -1. - * \param node The constructed node. + * \param attrs The attribute of the node * \return number of outputs. */ - int (*get_num_outputs)(const Node& node) = nullptr; + uint32_t (*get_num_outputs)(const NodeAttrs& attrs) = nullptr; + /*! + * \brief get number of inputs given information about the node. + * \param attrs The attribute of the node + * \return number of inputs + */ + uint32_t (*get_num_inputs)(const NodeAttrs& attrs) = nullptr; /*! * \brief Attribute parser to parse the NodeAttrs information. * @@ -143,19 +152,25 @@ class Op { * \param n The number of inputs to be set. * \return reference to self. */ - inline Op& set_num_inputs(int n); // NOLINT(*) + inline Op& set_num_inputs(uint32_t n); // NOLINT(*) + /*! + * \brief Set the get_num_outputs function. + * \param fn The function to be set. + * \return reference to self. + */ + inline Op& set_num_inputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*) /*! * \brief Set the num_outputs * \param n The number of outputs to be set. * \return reference to self. */ - inline Op& set_num_outputs(int n); // NOLINT(*) + inline Op& set_num_outputs(uint32_t n); // NOLINT(*) /*! * \brief Set the get_num_outputs function. * \param fn The function to be set. * \return reference to self. */ - inline Op& set_num_outputs(int (*fn)(const Node& node)); // NOLINT(*) + inline Op& set_num_outputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*) /*! * \brief Set the attr_parser function. * \param fn The number of outputs to be set. @@ -180,6 +195,7 @@ class Op { static const Op* Get(const std::string& op_name); /*! * \brief Get additional registered attribute about operators. + * If nothing has been registered, an empty OpMap will be returned. * \param attr_name The name of the attribute. * \return An OpMap of specified attr_name. * \tparam ValueType The type of the attribute. @@ -197,7 +213,7 @@ class Op { // internal constructor Op(); // get const reference to certain attribute - static const any& GetAttrMap(const std::string& key); + static const any* GetAttrMap(const std::string& key); // update the attribute OpMap static void UpdateAttrMap(const std::string& key, std::function updater); @@ -217,6 +233,13 @@ class OpMap { * \return the const reference to the content value. */ inline const ValueType& operator[](const Op* op) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param op The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + */ + inline const ValueType& get(const Op* op, const ValueType& def_value) const; /*! * \brief Check if the map has op as key. * \param op The key to the map @@ -262,8 +285,18 @@ class OpMap { // member function of Op template inline const OpMap& Op::GetAttr(const std::string& key) { - const any& ref = GetAttrMap(key); - return nnvm::get >(ref); + const any* ref = GetAttrMap(key); + if (ref == nullptr) { + UpdateAttrMap(key, [key](any* pmap) { + if (pmap->empty()) { + OpMap pm; + pm.attr_name_ = key; + *pmap = std::move(pm); + } + }); + ref = GetAttrMap(key); + } + return nnvm::get >(*ref); } template @@ -273,7 +306,7 @@ inline Op& Op::attr( // NOLINT(*) if (pmap->empty()) { OpMap pm; pm.attr_name_ = attr_name; - *pmap = pm; + *pmap = std::move(pm); } CHECK_EQ(pmap->type(), typeid(OpMap)) << "Attribute " << attr_name @@ -301,18 +334,22 @@ inline Op& Op::describe(const std::string& descr) { // NOLINT(*) return *this; } -inline Op& Op::set_num_inputs(int n) { // NOLINT(*) +inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*) this->num_inputs = n; return *this; } -inline Op& Op::set_num_outputs(int n) { // NOLINT(*) +inline Op& Op::set_num_inputs(uint32_t (*fn)(const NodeAttrs&)) { // NOLINT(*) + this->get_num_inputs = fn; + return *this; +} + +inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*) this->num_outputs = n; return *this; } -inline Op& Op::set_num_outputs(int (*fn)(const Node& node)) { // NOLINT(*) - this->num_outputs = kVarg; +inline Op& Op::set_num_outputs(uint32_t (*fn)(const NodeAttrs&)) { // NOLINT(*) this->get_num_outputs = fn; return *this; } @@ -338,6 +375,16 @@ inline const ValueType& OpMap::operator[](const Op* op) const { return data_[idx].first; } +template +inline const ValueType& OpMap::get(const Op* op, const ValueType& def_value) const { + const uint32_t idx = op->index_; + if (idx < data_.size() && data_[idx].second) { + return data_[idx].first; + } else { + return def_value; + } +} + } // namespace nnvm #endif // NNVM_OP_H_ diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h new file mode 100644 index 000000000000..bfcf5f6b8eaa --- /dev/null +++ b/nnvm/include/nnvm/op_attr_types.h @@ -0,0 +1,42 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file op_attr_types.h + * \brief Data structures that can appear in operator attributes. + */ +#ifndef NNVM_OP_ATTR_TYPES_H_ +#define NNVM_OP_ATTR_TYPES_H_ + +#include +#include +#include + +namespace nnvm { + +// These types are optional attributes in each op +// Some of them are needed for certain pass. + +/*! + * \brief Return list of input arguments names of each operator. + * + * \param attrs The attributes of the node. + * \return list of inputs + * \note Register under "FListInputNames", default return {"data"}. + * + * FListInputNames enables automatic variable creation for missing arguments. + */ +using FListInputNames = std::function (const NodeAttrs& attrs)>; + +/*! + * \brief Return list of output arguments names of each operator. + * + * \param attrs The attributes of the node. + * \return list of inputs + * \note Register under "FListOutputNames", default return {"outputs"}. + * + * FListOutputNames customized naming for operator outputs. + */ +using FListOutputNames = std::function (const NodeAttrs& attrs)>; + +} // namespace nnvm + +#endif // NNVM_OP_ATTR_TYPES_H_ diff --git a/nnvm/include/nnvm/symbolic.h b/nnvm/include/nnvm/symbolic.h new file mode 100644 index 000000000000..cba8c3bf8473 --- /dev/null +++ b/nnvm/include/nnvm/symbolic.h @@ -0,0 +1,148 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file symbolic.h + * \brief Symbolic graph construction API + * + * This API is optional, but useful to allow user + * to construct NNVM Graph easily, and quickly create + * front-end host languages. + */ +#ifndef NNVM_SYMBOLIC_H_ +#define NNVM_SYMBOLIC_H_ + +#include +#include +#include + +#include "./base.h" +#include "./node.h" + +namespace nnvm { +/*! + * \brief Symbol is used to represent the + */ +class Symbol { + public: + /*! \brief option passed to ListAttr */ + enum ListAttrOption { + /*! \brief recursively list all attributes */ + kRecursive, + /*! \brief only list attributes in current node */ + kShallow + }; + + /*! \brief output entries contained in the symbol */ + std::vector outputs; + + /*! + * \brief copy the symbol + * \return a deep copy of the symbolic graph. + */ + Symbol Copy() const; + /*! + * \brief print the symbol info to output stream. + * \param os the output stream we like to print to + */ + void Print(std::ostream &os) const; // NOLINT(*) + /*! + * \brief get the index th element from the returned tuple. + * \param index index of multi output + * \return the symbol corresponds to the indexed element. + */ + Symbol operator[] (size_t index) const; + /*! + * \brief List the arguments names. + * + * The position of the returned list also corresponds to calling position in operator() + * \return the arguments list of this symbol, they can be either named or unnamed (empty string). + */ + std::vector ListArguments() const; + /*! + * \brief List the names of outputs for this symbol. + * For normal operators, it is usually symbol node name + "_output" + * \return get the descriptions of outputs for this symbol. + */ + std::vector ListOutputs() const; + /*! + * \brief Compose the symbol with arguments, this changes the current symbol. + * The kwargs passed in can be in-complete, + * + * The rest of the symbols will remain the same name. + * + * \param positional arguments + * \param kwargs keyword arguments for the symbol + * \param name name of returned symbol. + */ + void Compose(const std::vector& args, + const std::unordered_map& kwargs, + const std::string& name); + /*! + * \brief Apply the symbol as a function, compose with arguments + * This is equivalent to Copy then Compose. + * \param args positional arguments for the symbol + * \param kwargs keyword arguments for the symbol + * \param name name of returned symbol. + * \return a new Symbol which is the composition of current symbol with its arguments + */ + Symbol operator () (const std::vector& args, + const std::unordered_map& kwargs, + const std::string& name) const; + /*! + * \brief Add control flow depenencies to operators involved in symbols. + * For grouped sybmbol, an error will be raised. + * This mutate current symbolic Node. + * + * \param src The symbols to depend on. + */ + void AddControlDeps(const Symbol& src); + /* + * \brief Get all the internal nodes of the symbol. + * \return symbol A new symbol whose output contains all the outputs of the symbols + * Including input variables and intermediate outputs. + */ + Symbol GetInternals() const; + /*! + * \brief set additional attributes to current node. + * This only works for symbol with outputs from single operators. + * For grouped sybmbol, an error will be raised. + * + * This function mutate the node's symbol and is not recommended. + * + * \param key the key of the attribute + * \param value the value of the attribute. + */ + void SetAttrs(const std::vector >& attrs); + /*! + * \brief Get attribute dictionary from the symbol. + * For grouped sybmbol, an error will be raised. + * \param option If recursive is set, the attributes of all children are retrieved, + * The name of symbol will be pre-pended to each key. + * \return The created attribute. + */ + std::unordered_map ListAttr(ListAttrOption option) const; + /*! + * \brief create symbolic functor(AtomicSymbol) by given operator and attributes. + * \param op_name The name of the operator. + * \param attrs The additional attributes. + * + * \return Symbol that can be used to call compose further. + */ + static Symbol CreateFunctor(const std::string& op_name, + const std::unordered_map& attrs); + /*! + * \brief create variable symbol node + * \param name name of the variable + * \return the new variable + */ + static Symbol CreateVariable(const std::string& name); + /*! + * \brief create equivalence of symbol by grouping the symbols together + * \param symbols list of symbols + * \return the grouped symbol + */ + static Symbol CreateGroup(const std::vector& symbols); +}; + +} // namespace nnvm + +#endif // NNVM_SYMBOLIC_H_ diff --git a/nnvm/src/core/graph_attr_types.cc b/nnvm/src/core/graph_attr_types.cc index 43351f544835..aaefa0626131 100644 --- a/nnvm/src/core/graph_attr_types.cc +++ b/nnvm/src/core/graph_attr_types.cc @@ -13,7 +13,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { entry_rptr_.push_back(0); std::vector inputs_rptr{0}, control_rptr{0}; - g.DFSVisit([this, &inputs_rptr, &control_rptr] + DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr] (const std::shared_ptr& n) { CHECK_LT(nodes_.size(), std::numeric_limits::max()); uint32_t nid = static_cast(nodes_.size()); diff --git a/nnvm/src/core/op.cc b/nnvm/src/core/op.cc index 57514aa7fbcb..a5518ace6be1 100644 --- a/nnvm/src/core/op.cc +++ b/nnvm/src/core/op.cc @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -23,7 +24,7 @@ struct OpManager { // global operator counter std::atomic op_counter{0}; // storage of additional attribute table. - std::unordered_map attr; + std::unordered_map > attr; // get singleton of the static OpManager* Global() { static OpManager inst; @@ -46,24 +47,24 @@ const Op* Op::Get(const std::string& name) { } // Get attribute map by key -const any& Op::GetAttrMap(const std::string& key) { - // assume no operator registration during - // the execution phase. - const auto& dict = OpManager::Global()->attr; +const any* Op::GetAttrMap(const std::string& key) { + auto& dict = OpManager::Global()->attr; auto it = dict.find(key); - CHECK(it != dict.end() && it->first == key) - << "Cannot find Operator attribute " << key - << " for any operator"; - return it->second; + if (it != dict.end()) { + return it->second.get(); + } else { + return nullptr; + } } -// update attribute map by updater function. +// update attribute map void Op::UpdateAttrMap(const std::string& key, std::function updater) { OpManager* mgr = OpManager::Global(); std::lock_guard(mgr->mutex); - any& value = mgr->attr[key]; - updater(&value); + std::unique_ptr& value = mgr->attr[key]; + if (value.get() == nullptr) value.reset(new any()); + if (updater != nullptr) updater(value.get()); } } // namespace nnvm diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc new file mode 100644 index 000000000000..03a6f065b008 --- /dev/null +++ b/nnvm/src/core/symbolic.cc @@ -0,0 +1,369 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file symbolic.cc + * \brief Symbolic graph composition API. + */ +#include +#include +#include + +namespace nnvm { + +namespace symbol_constants { +const char *kNamespaceSeparator = "_"; +} // namespace symbol_constants + + +inline std::string DefaultVarName(const std::string &op_name, + const std::string &arg_name) { + if (op_name.length() == 0) { + return arg_name; + } else { + return op_name + '_' + arg_name; + } +} + +inline void KeywordArgumentMismatch(const char *source, + const std::vector& user_args, + const array_view& args) { + std::unordered_set keys(args.begin(), args.end()); + std::ostringstream head, msg; + msg << "\nCandidate arguments:\n"; + for (size_t i = 0; i < args.size(); ++i) { + msg << "\t[" << i << ']' << args[i] << '\n'; + } + + for (const auto& key : user_args) { + if (keys.count(key) == 0) { + LOG(FATAL) << source + << "Keyword argument name " << key << " not found." + << msg.str(); + } + } +} + +template +inline std::vector GetKeys( + const std::unordered_map& kwargs) { + std::vector keys(kwargs.size()); + std::transform(kwargs.begin(), kwargs.end(), keys.begin(), + [](decltype(*kwargs.begin())& kv) { return kv.first; }); + return keys; +} + +// whether the symbol is atomic functor +inline bool IsAtomic(const std::vector& outputs) { + return outputs.size() == 1 && outputs[0].node->inputs.size() == 0; +} + +// public functions +Symbol Symbol::Copy() const { + std::unordered_map > old_new; + // use DFSVisit to copy all the nodes + DFSVisit(this->outputs, [&old_new](const std::shared_ptr& node) { + old_new[node.get()] = std::make_shared(*node); + }); + // connect nodes of new graph + for (const auto &kv : old_new) { + for (const NodeEntry& e : kv.first->inputs) { + Node *ptr = e.node.get(); + kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index}); + } + } + // set the head + Symbol ret; + for (const NodeEntry &e : outputs) { + ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index}); + } + return ret; +} + +void Symbol::Print(std::ostream &os) const { + if (outputs.size() == 1 && outputs[0].node->inputs.size() == 0) { + os << "AtomicFunctor "<< " Op:" << outputs[0].node->op->name << '\n'; + } else { + // use DFSVisit to copy all the nodes + os << "Outputs:\n"; + for (size_t i = 0; i < outputs.size(); ++i) { + os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name + << '(' << outputs[i].index << ")\n"; + } + DFSVisit(this->outputs, [&os](const std::shared_ptr& node) { + if (node->is_variable()) { + os << "Variable:" << node->attrs.name << '\n'; + } else { + os << "Name: " << node->attrs.name << " Op:" << node->op->name << '\n' + << "Inputs:\n"; + for (size_t i = 0; i < node->inputs.size(); ++i) { + os << "\targ[" << i << "]=" << node->inputs[i].node->attrs.name + << '(' << node->inputs[i].index << ")\n"; + } + os << "Attrs:\n"; + for (auto &kv : node->attrs.dict) { + os << '\t' << kv.first << '=' << kv.second << '\n'; + } + } + }); + } +} + +Symbol Symbol::operator[] (size_t index) const { + size_t nreturn = outputs.size(); + CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index"; + if (nreturn == 1) { + return *this; + } else { + Symbol s; + s.outputs.push_back(outputs[index]); + return s; + } +} + +std::vector Symbol::ListArguments() const { + std::vector ret; + DFSVisit(this->outputs, [&ret](const std::shared_ptr &node) { + if (node->is_variable()) { + ret.push_back(node->attrs.name); + } + }); + return ret; +} + +std::vector Symbol::ListOutputs() const { + static auto& flist_ouputs = Op::GetAttr("FListOutputNames"); + std::vector ret; + for (auto &head : outputs) { + if (head.node->is_variable()) { + ret.push_back(head.node->attrs.name); + } else { + const std::string& hname = head.node->attrs.name; + std::string rname; + FListOutputNames fn = flist_ouputs.get(head.node->op, nullptr); + if (fn != nullptr) { + rname = fn(head.node->attrs)[head.index]; + } else { + rname = "output"; + if (head.node->num_outputs() != 1) { + std::ostringstream os; + os << rname << head.index; + rname = os.str(); + } + } + if (hname.length() == 0) { + ret.push_back(std::move(rname)); + } else { + ret.push_back(hname + '_' + rname); + } + } + } + return ret; +} + +// compositional logic +void Symbol::Compose(const std::vector& args, + const std::unordered_map& kwargs, + const std::string& name) { + CHECK_EQ(outputs.size(), 1) + << "Only composition of value function is supported currently"; + CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed"; + // parameter check. + for (size_t i = 0; i < args.size(); ++i) { + CHECK_EQ(args[i].outputs.size(), 1) + << "Argument " << i << " is a tuple, single value is required"; + } + for (const auto& kv : kwargs) { + CHECK_EQ(kv.second.outputs.size(), 1) + << "Keyword Argument " << kv.first << " is a tuple, single value is required"; + } + // assign new name + outputs[0].node->attrs.name = name; + + // Atomic functor composition. + if (IsAtomic(outputs)) { + Node* n = outputs[0].node.get(); + uint32_t n_req = n->num_inputs(); + + if (n_req != kVarg) { + n->inputs.resize(n_req); + CHECK_LE(args.size(), n_req) + << "Incorrect number of arguments, requires " << n_req + << ", provided " << args.size(); + for (size_t i = 0; i < args.size(); ++i) { + n->inputs[i] = args[i].outputs[0]; + } + // switch to keyword argument matching + if (args.size() != n_req) { + static auto& flist_inputs = Op::GetAttr("FListInputNames"); + FListInputNames fn = flist_inputs.get(n->op, nullptr); + auto arg_names = (fn == nullptr) ? std::vector{"data"} : fn(n->attrs); + CHECK_EQ(arg_names.size(), n_req); + + size_t nmatched = 0; + for (size_t i = args.size(); i < n_req; ++i) { + auto it = kwargs.find(arg_names[i]); + if (it != kwargs.end() && it->first == arg_names[i]) { + n->inputs[i] = it->second.outputs[0]; + ++nmatched; + } else { + n->inputs[i] = NodeEntry{Node::Create(), 0}; + n->inputs[i].node->attrs.name = DefaultVarName(name, arg_names[i]); + } + } + + if (nmatched != kwargs.size()) { + n->inputs.clear(); + std::vector keys = GetKeys(kwargs); + array_view view(dmlc::BeginPtr(arg_names) + args.size(), + dmlc::BeginPtr(arg_names) + arg_names.size()); + KeywordArgumentMismatch("Symbol.Compose", keys, view); + } + } + } else { + CHECK_EQ(kwargs.size(), 0) << "Variable length function do not accept kwargs"; + n->inputs.reserve(args.size()); + for (const Symbol& s : args) { + n->inputs.push_back(s.outputs[0]); + } + } + } else { + // general composition + CHECK_EQ(args.size(), 0) + << "General composition only support kwargs for now"; + size_t nmatched = 0; + size_t arg_counter = 0; + std::unordered_map replace_map; + // replace map stores the existing replacement plan for arguments node + auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map] + (const std::shared_ptr &node) { + if (node->is_variable()) { + if (arg_counter < args.size()) { + replace_map[node.get()] = &(args[arg_counter].outputs[0]); + ++arg_counter; + } else { + // match kwargs + auto kit = kwargs.find(node->attrs.name); + if (kit != kwargs.end()) { + replace_map[node.get()] = &(kit->second.outputs[0]); + ++nmatched; + } + } + } + }; + DFSVisit(this->outputs, find_replace_map); + + if (nmatched == kwargs.size() && arg_counter < args.size()) { + std::vector > replace_plan; + auto find_replace_plan = [&replace_map, &replace_plan] + (const std::shared_ptr &node) { + // visit all the childs, find possible replacement + for (size_t i = 0; i < node->inputs.size(); ++i) { + NodeEntry *e = &(node->inputs[i]); + if (e->node->is_variable()) { + auto iter = replace_map.find(e->node.get()); + if (iter != replace_map.end()) { + replace_plan.push_back(std::make_pair(e, iter->second)); + } + } + } + }; + DFSVisit(this->outputs, find_replace_plan); + + for (const auto& kv : replace_plan) { + *(kv.first) = *(kv.second); + } + } else { + std::vector keys = GetKeys(kwargs); + std::vector arg_names = ListArguments(); + array_view view(dmlc::BeginPtr(arg_names) + arg_counter, + dmlc::BeginPtr(arg_names) + arg_names.size()); + KeywordArgumentMismatch("Symbol.Compose", keys, ListArguments()); + } + } +} + +Symbol Symbol::operator () (const std::vector& args, + const std::unordered_map& kwargs, + const std::string& name) const { + Symbol s = this->Copy(); + s.Compose(args, kwargs, name); + return s; +} + +void Symbol::AddControlDeps(const Symbol& src) { + CHECK_EQ(outputs.size(), 1) + << "AddControlDeps only works for nongrouped symbol"; + Node* n = outputs[0].node.get(); + for (const NodeEntry& sp : src.outputs) { + n->control_deps.push_back(sp.node); + } +} + +Symbol Symbol::GetInternals() const { + Symbol ret; + DFSVisit(this->outputs, [&ret](const std::shared_ptr& node) { + Node* n = node.get(); + uint32_t nout = n->num_outputs(); + for (uint32_t i = 0; i < nout; ++i) { + ret.outputs.emplace_back(NodeEntry{node, i}); + } + }); + return ret; +} + +void Symbol::SetAttrs(const std::vector >& attrs) { + CHECK_EQ(outputs.size(), 1) + << "SetAttrs only works for nongrouped symbol"; + Node* n = outputs[0].node.get(); + for (const auto& kv : attrs) { + n->attrs.dict[kv.first] = kv.second; + } + if (n->op->attr_parser != nullptr) { + (*n->op->attr_parser)(&(n->attrs)); + } +} + +std::unordered_map Symbol::ListAttr(ListAttrOption option) const { + if (option == kRecursive) { + std::unordered_map ret; + DFSVisit(this->outputs, [&ret](const std::shared_ptr& n) { + for (const auto& it : n->attrs.dict) { + ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second; + } + }); + return ret; + } else { + return outputs[0].node->attrs.dict; + } +} + +Symbol Symbol::CreateFunctor(const std::string& op_name, + const std::unordered_map& attrs) { + Symbol s; + std::shared_ptr n = Node::Create(); + n->op = Op::Get(op_name); + n->attrs.dict = attrs; + if (n->op->attr_parser != nullptr) { + (*n->op->attr_parser)(&(n->attrs)); + } + s.outputs.emplace_back(NodeEntry{std::move(n), 0}); + return s; +} + +Symbol Symbol::CreateGroup(const std::vector &symbols) { + Symbol ret; + for (const auto &s : symbols) { + ret.outputs.insert(ret.outputs.end(), s.outputs.begin(), s.outputs.end()); + } + return ret; +} + +Symbol Symbol::CreateVariable(const std::string& name) { + Symbol s; + std::shared_ptr n = Node::Create(); + n->op = nullptr; + n->attrs.name = name; + s.outputs.emplace_back(NodeEntry{std::move(n), 0}); + return s; +} + +} // namespace nnvm diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 9e82f1bd59cc..f05e4873e2bf 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -158,7 +158,7 @@ Graph LoadJSON(const Graph& src) { Graph SaveJSON(const Graph& src) { JSONGraph jgraph; std::unordered_map node2index; - src.DFSVisit([&node2index, &jgraph](const std::shared_ptr& n) { + DFSVisit(src.outputs, [&node2index, &jgraph](const std::shared_ptr& n) { uint32_t nid = static_cast(jgraph.nodes.size()); node2index[n.get()] = nid; if (n->is_variable()) { diff --git a/nnvm/src/test_main.cc b/nnvm/src/test_main.cc index 6a6946171b38..51d4ec1513cf 100644 --- a/nnvm/src/test_main.cc +++ b/nnvm/src/test_main.cc @@ -8,7 +8,7 @@ void test_op() { using namespace nnvm; auto add = Op::Get("add"); - auto nick = Op::GetAttr("nick_name"); + static auto& nick = Op::GetAttr("nick_name"); LOG(INFO) << "nick=" << nick[add]; } @@ -35,9 +35,7 @@ void test_tuple() { void test_graph() { - nnvm::Graph g; - g.DFSVisit([](const std::shared_ptr& n){ - }); + nnvm::Symbol s; } int main() { test_tuple(); diff --git a/nnvm/tests/cpp/op_test.cc b/nnvm/tests/cpp/op_test.cc index c0a8913f4aef..816305c609ca 100644 --- a/nnvm/tests/cpp/op_test.cc +++ b/nnvm/tests/cpp/op_test.cc @@ -1,19 +1,19 @@ #include #include -#include +#include #include -NNGRAPH_REGISTER_OP(add) +NNVM_REGISTER_OP(add) .describe("add two data together") .set_num_inputs(2) .attr("inplace_pair", std::make_pair(0, 0)); -NNGRAPH_REGISTER_OP(add) +NNVM_REGISTER_OP(add) .attr("nick_name", "plus"); TEST(Op, GetAttr) { - using namespace nngraph; + using namespace nnvm; auto add = Op::Get("add"); auto nick = Op::GetAttr("nick_name"); diff --git a/nnvm/tests/cpp/tuple_test.cc b/nnvm/tests/cpp/tuple_test.cc index 4750e01d2541..496e62ebe9a5 100644 --- a/nnvm/tests/cpp/tuple_test.cc +++ b/nnvm/tests/cpp/tuple_test.cc @@ -1,10 +1,10 @@ #include #include -#include +#include TEST(Tuple, Basic) { - using nngraph::Tuple; - using nngraph::TShape; + using nnvm::Tuple; + using nnvm::TShape; Tuple x{1, 2, 3}; Tuple y{1, 2, 3, 5, 6}; x = std::move(y); @@ -17,7 +17,7 @@ TEST(Tuple, Basic) { std::istringstream is(os.str()); is >> y; CHECK_EQ(x, y); - Tuple ss{1, 2, 3}; + Tuple ss{1, 2, 3}; TShape s = ss; s = std::move(ss); CHECK((s == TShape{1, 2, 3}));