Skip to content

Commit

Permalink
Deconvolution getDstMemDesc fix
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick authored and mandrono committed Jul 20, 2021
1 parent 7c3865d commit 7da61dd
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,11 @@ bool BlockedMemoryDesc::isCompatible(const BlockedMemoryDesc& rhs) const {
return false;
}

if (this->getOrder() != rhs.getOrder()) {
if (!isEqualOrUndefined(this->getOrder(), rhs.getOrder())) {
return false;
}

return !(this->getOffsetPadding() != rhs.getOffsetPadding() &&
this->getOffsetPadding() != Shape::UNDEFINED_DIM && rhs.getOffsetPadding() != Shape::UNDEFINED_DIM);
return dimsEqualWeak(this->getOffsetPadding(), rhs.getOffsetPadding());
}

size_t BlockedMemoryDesc::getMemSizeImp() const {
Expand Down
2 changes: 1 addition & 1 deletion inference-engine/src/mkldnn_plugin/mkldnn_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ bool MKLDNNMemoryDesc::isCompatible(const BlockedMemoryDesc &rhs) const {

size_t blk_offset0 = desc.data.offset0;

return !(blk_offset0 != rhs.getOffsetPadding() && blk_offset0 != Shape::UNDEFINED_DIM && rhs.getOffsetPadding() != Shape::UNDEFINED_DIM);
return dimsEqualWeak(blk_offset0, rhs.getOffsetPadding());
}

bool MKLDNNMemoryDesc::checkGeneralLayout(GeneralLayout layoutType) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,8 @@ std::unique_ptr<MKLDNNMemoryDesc> MKLDNNDeconvolutionNode::getSrcMemDesc(mkldnn:
}

std::unique_ptr<MKLDNNMemoryDesc> MKLDNNDeconvolutionNode::getDstMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) {
MKLDNNMemoryDesc desc = isInt8 ? MKLDNNMemoryDesc(primitive_desc_it.dst_desc(idx))
: MKLDNNMemoryDesc(primitive_desc_it.diff_src_desc(idx));
return MKLDNNPlugin::make_unique<MKLDNNMemoryDesc>(getChildEdgeAt(idx)->getShape().getStaticMklDims(), desc.getDataType(), desc.getFormat());
return isInt8 ? MKLDNNPlugin::make_unique<MKLDNNMemoryDesc>(primitive_desc_it.dst_desc(idx)) :
MKLDNNPlugin::make_unique<MKLDNNMemoryDesc>(primitive_desc_it.diff_src_desc(idx));
}

InferenceEngine::Precision MKLDNNDeconvolutionNode::getRuntimePrecision() const {
Expand Down

0 comments on commit 7da61dd

Please sign in to comment.