diff --git a/src/pass/infer_shape_type.cc b/src/pass/infer_shape_type.cc index a18528b5b465..445787f7e13f 100644 --- a/src/pass/infer_shape_type.cc +++ b/src/pass/infer_shape_type.cc @@ -11,15 +11,16 @@ namespace nnvm { namespace pass { namespace { -template +template Graph InferAttr(Graph &&ret, - const AttrType default_val, + const AttrType empty_val, const char* infer_name, const char* input_name, const char* attr_key_name, const char* attr_name, const char* unknown_name, - IsNone fis_none) { + IsNone fis_none, + FDefault fdefault) { using AttrVector = std::vector; const IndexedGraph& idx = ret.indexed_graph(); static auto& finfer_shape = @@ -31,7 +32,7 @@ Graph InferAttr(Graph &&ret, if (ret.attrs.count(attr_name) != 0) { rshape = ret.MoveCopyAttr(attr_name); } else { - rshape.resize(idx.num_node_entries(), default_val); + rshape.resize(idx.num_node_entries(), empty_val); } if (ret.attrs.count(input_name) != 0) { @@ -51,12 +52,12 @@ Graph InferAttr(Graph &&ret, // erase the provided arguments ret.attrs.erase(attr_key_name); } - // Temp space for shape inference. std::vector ishape, oshape; - // number of completed nodes - size_t num_unknown = 0; - for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { + size_t num_unknown; + + // inference step function for nid + auto infer_step = [&](uint32_t nid) { const auto& inode = idx[nid]; const uint32_t num_inputs = inode.inputs.size(); const uint32_t num_outputs = inode.source->num_outputs(); @@ -72,27 +73,6 @@ Graph InferAttr(Graph &&ret, CHECK(is >> rshape[out_ent_id]) << "Invalid attribute"; } } - } else if (finfer_shape.count(inode.source->op())) { - // Forward operator inference. - ishape.resize(num_inputs, default_val); - for (uint32_t i = 0; i < ishape.size(); ++i) { - ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; - } - oshape.resize(num_outputs, default_val); - for (uint32_t i = 0; i < oshape.size(); ++i) { - oshape[i] = rshape[idx.entry_id(nid, i)]; - } - // Call inference function of the operator. - bool forward_known = finfer_shape[inode.source->op()]( - inode.source->attrs, &ishape, &oshape); - num_unknown += !forward_known; - // Save to the result map. - for (uint32_t i = 0; i < num_inputs; ++i) { - rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; - } - for (uint32_t i = 0; i < num_outputs; ++i) { - rshape[idx.entry_id(nid, i)] = oshape[i]; - } } else if (backward_map.count(inode.source->op())) { // Backward operator inference. CHECK_GE(inode.control_deps.size(), 1) @@ -111,6 +91,47 @@ Graph InferAttr(Graph &&ret, if (fis_none(rshape[idx.entry_id(nid, i)])) known = false; } num_unknown += !known; + } else { + bool forward_known = true; + // Forward operator inference. + ishape.resize(num_inputs, empty_val); + for (uint32_t i = 0; i < ishape.size(); ++i) { + ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; + if (fis_none(ishape[i])) forward_known = false; + } + oshape.resize(num_outputs, empty_val); + for (uint32_t i = 0; i < oshape.size(); ++i) { + oshape[i] = rshape[idx.entry_id(nid, i)]; + if (fis_none(oshape[i])) forward_known = false; + } + if (!forward_known) { + auto finfer = finfer_shape.get(inode.source->op(), fdefault); + CHECK(finfer != nullptr) + << "Attribute " << infer_name + << " is not registed by op " << inode.source->op()->name; + // Call inference function of the operator. + forward_known = finfer(inode.source->attrs, &ishape, &oshape); + } + num_unknown += !forward_known; + // Save to the result map. + for (uint32_t i = 0; i < num_inputs; ++i) { + rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; + } + for (uint32_t i = 0; i < num_outputs; ++i) { + rshape[idx.entry_id(nid, i)] = oshape[i]; + } + } + }; + + num_unknown = 0; + for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { + infer_step(nid); + } + if (num_unknown != 0) { + num_unknown = 0; + // backward inference + for (uint32_t i = idx.num_nodes(); i != 0; --i) { + infer_step(i - 1); } } // set the shapes @@ -127,19 +148,48 @@ NNVM_REGISTER_PASS(InferShape) std::move(ret), TShape(), "FInferShape", "shape_inputs", "shape_attr_key", "shape", "shape_num_unknown_nodes", - [](const TShape& s) { return s.ndim() == 0; }); + [](const TShape& s) { return s.ndim() == 0; }, + nullptr); }) .set_change_graph(false) .provide_graph_attr("shape"); +// inference fucntion for same type +inline bool SameType(const NodeAttrs& attrs, + std::vector *iattr, + std::vector *oattr) { + int def_v = -1; + for (int v : *oattr) { + if (v != -1) { + def_v = v; break; + } + } + if (def_v == -1) { + for (int v : *iattr) { + if (v != -1) { + def_v = v; break; + } + } + } + if (def_v == -1) return false; + for (int& v : *oattr) { + v = def_v; + } + for (int& v : *iattr) { + v = def_v; + } + return true; +} + NNVM_REGISTER_PASS(InferType) .describe("Infer the dtype of each node entries.") .set_body([](Graph ret) { return InferAttr( - std::move(ret), 0, + std::move(ret), -1, "FInferType", "dtype_inputs", "dtype_attr_key", "dtype", "dtype_num_unknown_nodes", - [](const int t) { return t == -1; }); + [](const int t) { return t == -1; }, + SameType); }) .set_change_graph(false) .provide_graph_attr("dtype");