diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc index 0f0b8e79b9c2..f77e70085fbc 100644 --- a/nnvm/src/pass/plan_memory.cc +++ b/nnvm/src/pass/plan_memory.cc @@ -173,7 +173,10 @@ Graph PlanMemory(Graph ret) { for (auto& kv : inplace_pairs) { uint32_t eid_out = idx.entry_id(nid, kv.second); uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]); - if (ref_count[eid_in] == 1 && storage[eid_in] != GraphAllocator::kBadStorageID) { + if (ref_count[eid_in] == 1 && + storage[eid_in] != GraphAllocator::kBadStorageID && + shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && + dtype_vec[eid_out] == dtype_vec[eid_in]) { // inplace optimization storage[eid_out] = storage[eid_in]; ref_count[eid_in] = 0;