From 3865410a3bc68429cdf7348f104b08a922012683 Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 16 May 2018 07:42:55 +0530 Subject: [PATCH] Better to check Infer result with topi results at build time instead of leaving to a runtime error. (#476) --- nnvm/src/compiler/compile_engine.cc | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc index 1ef39f851203..a9d4aa2d016a 100644 --- a/nnvm/src/compiler/compile_engine.cc +++ b/nnvm/src/compiler/compile_engine.cc @@ -202,6 +202,22 @@ class CompileEngine { Array out = fcompute[inode.source->op()]( inode.source->attrs, op_inputs, out_info); CHECK_EQ(out.size(), inode.source->num_outputs()); + + // check output dimentions also match + // This check is to make sure the NNVM operator Infer match with Compute result. + // Missing this check may pass the build but leads to runtime errors. + for (uint32_t i = 0; i < out.size(); ++i) { + CHECK_EQ(out[i].ndim(), out_info[i].ndim()) << inode.source->op()->name; + tvm::Tensor inferred_tensor = out[i]; + tvm::Tensor computed_tensor = out_info[i]; + for (uint32_t j = 0; j < inferred_tensor->shape.size(); ++j) { + if ((as_const_int(inferred_tensor->shape[j])) && + (as_const_int(computed_tensor->shape[j]))) + CHECK_EQ((*as_const_int(inferred_tensor->shape[j])), + (*as_const_int(computed_tensor->shape[j]))) << inode.source->op()->name; + } + } + // schedule on root node, and use master's schedule for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { uint32_t eid = idx.entry_id(nid, index);