Skip to content

Commit

Permalink
[Infer] More robust inference, support backward inference (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Sep 22, 2016
1 parent 79cf63b commit e443353
Showing 1 changed file with 82 additions and 32 deletions.
114 changes: 82 additions & 32 deletions src/pass/infer_shape_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ namespace nnvm {
namespace pass {
namespace {

template<typename AttrType, typename IsNone>
template<typename AttrType, typename IsNone, typename FDefault>
Graph InferAttr(Graph &&ret,
const AttrType default_val,
const AttrType empty_val,
const char* infer_name,
const char* input_name,
const char* attr_key_name,
const char* attr_name,
const char* unknown_name,
IsNone fis_none) {
IsNone fis_none,
FDefault fdefault) {
using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape =
Expand All @@ -31,7 +32,7 @@ Graph InferAttr(Graph &&ret,
if (ret.attrs.count(attr_name) != 0) {
rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
} else {
rshape.resize(idx.num_node_entries(), default_val);
rshape.resize(idx.num_node_entries(), empty_val);
}

if (ret.attrs.count(input_name) != 0) {
Expand All @@ -51,12 +52,12 @@ Graph InferAttr(Graph &&ret,
// erase the provided arguments
ret.attrs.erase(attr_key_name);
}

// Temp space for shape inference.
std::vector<AttrType> ishape, oshape;
// number of completed nodes
size_t num_unknown = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
size_t num_unknown;

// inference step function for nid
auto infer_step = [&](uint32_t nid) {
const auto& inode = idx[nid];
const uint32_t num_inputs = inode.inputs.size();
const uint32_t num_outputs = inode.source->num_outputs();
Expand All @@ -72,27 +73,6 @@ Graph InferAttr(Graph &&ret,
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
}
}
} else if (finfer_shape.count(inode.source->op())) {
// Forward operator inference.
ishape.resize(num_inputs, default_val);
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
}
oshape.resize(num_outputs, default_val);
for (uint32_t i = 0; i < oshape.size(); ++i) {
oshape[i] = rshape[idx.entry_id(nid, i)];
}
// Call inference function of the operator.
bool forward_known = finfer_shape[inode.source->op()](
inode.source->attrs, &ishape, &oshape);
num_unknown += !forward_known;
// Save to the result map.
for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
}
for (uint32_t i = 0; i < num_outputs; ++i) {
rshape[idx.entry_id(nid, i)] = oshape[i];
}
} else if (backward_map.count(inode.source->op())) {
// Backward operator inference.
CHECK_GE(inode.control_deps.size(), 1)
Expand All @@ -111,6 +91,47 @@ Graph InferAttr(Graph &&ret,
if (fis_none(rshape[idx.entry_id(nid, i)])) known = false;
}
num_unknown += !known;
} else {
bool forward_known = true;
// Forward operator inference.
ishape.resize(num_inputs, empty_val);
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
if (fis_none(ishape[i])) forward_known = false;
}
oshape.resize(num_outputs, empty_val);
for (uint32_t i = 0; i < oshape.size(); ++i) {
oshape[i] = rshape[idx.entry_id(nid, i)];
if (fis_none(oshape[i])) forward_known = false;
}
if (!forward_known) {
auto finfer = finfer_shape.get(inode.source->op(), fdefault);
CHECK(finfer != nullptr)
<< "Attribute " << infer_name
<< " is not registed by op " << inode.source->op()->name;
// Call inference function of the operator.
forward_known = finfer(inode.source->attrs, &ishape, &oshape);
}
num_unknown += !forward_known;
// Save to the result map.
for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
}
for (uint32_t i = 0; i < num_outputs; ++i) {
rshape[idx.entry_id(nid, i)] = oshape[i];
}
}
};

num_unknown = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
infer_step(nid);
}
if (num_unknown != 0) {
num_unknown = 0;
// backward inference
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
infer_step(i - 1);
}
}
// set the shapes
Expand All @@ -127,19 +148,48 @@ NNVM_REGISTER_PASS(InferShape)
std::move(ret), TShape(),
"FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes",
[](const TShape& s) { return s.ndim() == 0; });
[](const TShape& s) { return s.ndim() == 0; },
nullptr);
})
.set_change_graph(false)
.provide_graph_attr("shape");

// inference fucntion for same type
inline bool SameType(const NodeAttrs& attrs,
std::vector<int> *iattr,
std::vector<int> *oattr) {
int def_v = -1;
for (int v : *oattr) {
if (v != -1) {
def_v = v; break;
}
}
if (def_v == -1) {
for (int v : *iattr) {
if (v != -1) {
def_v = v; break;
}
}
}
if (def_v == -1) return false;
for (int& v : *oattr) {
v = def_v;
}
for (int& v : *iattr) {
v = def_v;
}
return true;
}

NNVM_REGISTER_PASS(InferType)
.describe("Infer the dtype of each node entries.")
.set_body([](Graph ret) {
return InferAttr<int>(
std::move(ret), 0,
std::move(ret), -1,
"FInferType", "dtype_inputs", "dtype_attr_key",
"dtype", "dtype_num_unknown_nodes",
[](const int t) { return t == -1; });
[](const int t) { return t == -1; },
SameType);
})
.set_change_graph(false)
.provide_graph_attr("dtype");
Expand Down

0 comments on commit e443353

Please sign in to comment.