diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 3e89bb8a37d1..b2d5a5254939 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -255,20 +255,19 @@ void NDArray::set_fresh_out_grad(bool state) const { } #if MXNET_USE_MKLDNN == 1 -static inline bool same_shape(const TShape &shape, mkldnn_dims_t dims, int ndims) { +static inline bool same_shape(const TShape &shape, mkldnn::memory::primitive_desc pd) { + int ndims = pd.desc().data.ndims; if (shape.ndim() != ndims) return false; for (int i = 0; i < ndims; i++) - if (shape[i] != dims[i]) + if (shape[i] != pd.desc().data.dims[i]) return false; return true; } void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) { - if (Mkl_mem_ && same_shape(shape, Mkl_mem_->get_primitive_desc().desc().data.dims, - Mkl_mem_->get_primitive_desc().desc().data.ndims)) { + if (Mkl_mem_ && same_shape(shape, Mkl_mem_->get_primitive_desc())) return; - } mkldnn::memory::dims dims(shape.ndim()); for (size_t i = 0; i < dims.size(); i++) @@ -304,6 +303,10 @@ static int GetTypeSize(int dtype) { std::shared_ptr NDArray::GetMKLDNNData( const mkldnn::memory::primitive_desc &desc) const { + // If the array size doesn't match, we should reset MKL memory. + if (ptr_->Mkl_mem_ && !same_shape(shape(), ptr_->Mkl_mem_->get_primitive_desc())) + ptr_->Mkl_mem_ = nullptr; + if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) { LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc"; return nullptr; @@ -319,6 +322,10 @@ std::shared_ptr NDArray::GetMKLDNNData( std::shared_ptr NDArray::GetMKLDNNDataReorder( const mkldnn::memory::primitive_desc &desc) const { + // If the array size doesn't match, we should reset MKL memory. + if (ptr_->Mkl_mem_ && !same_shape(shape(), ptr_->Mkl_mem_->get_primitive_desc())) + ptr_->Mkl_mem_ = nullptr; + if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) { LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc"; return nullptr; @@ -388,6 +395,7 @@ void NDArray::SetTBlob() const { } else if (stype == kMKLDNNStorage) { // TODO we may really need to convert format. CHECK_EQ(byte_offset_, 0); + ptr_->SetMKLMem(shape_, dtype_); dptr = (char *) ptr_->Mkl_mem_->get_data_handle(); #endif } else {