From d1325be02e1a4fd7aef8e16ed70ed37e84ab5e8e Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 5 Jun 2019 16:27:16 -0700 Subject: [PATCH] Ghost nodes in NNVM graph (#3290) --- nnvm/include/nnvm/op_attr_types.h | 11 +++++++++++ nnvm/src/core/graph.cc | 3 +++ 2 files changed, 14 insertions(+) diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index 976ad929f496..ad328c30312a 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -136,6 +136,17 @@ using FInferType = FInferNodeEntryAttr; */ using TIsBackward = bool; +/*! + * \brief Whether this op is a ghost node. + * If TIsGhost is true: + * - The node with this op will not be visible in the indexed graph. + * + * \note Register under "TIsGhost" + * This enables shape/type inference for backward nodes when + * fusion is present. + */ +using TIsGhost = bool; + /*! * \brief Get possible inplace options. * This function enables optimization to reuse memory of inputs in output. diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index 92ff98618ec8..29149f48fdb0 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -76,6 +76,8 @@ IndexedGraph::IndexedGraph(const Graph &g) { DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs] (const NodePtr& n) { + const auto& is_ghost = Op::GetAttr("TIsGhost"); + if (!n->is_variable() && is_ghost.get(n->op(), false)) return; CHECK_LT(nodes_.size(), std::numeric_limits::max()); uint32_t nid = static_cast(nodes_.size()); CHECK(n); @@ -103,6 +105,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { inputs_rptr.push_back(input_entries_.size()); // control deps for (const auto& nptr : n->control_deps) { + if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue; auto it = node2index_.find(nptr.get()); CHECK(it != node2index_.end() && it->first == nptr.get()); control_deps_.push_back(it->second);