Skip to content

Commit

Permalink
[CPU] Impl extract_image_patches cache (openvinotoolkit#9525)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangYiIntel authored Jan 11, 2022
1 parent 1a3d0ad commit 986f0ea
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "list.hpp"
#include <cpu/x64/jit_generator.hpp>
#include "caseless.hpp"
#include <common/primitive_hashing_utils.hpp>

using namespace MKLDNNPlugin;
using namespace InferenceEngine;
Expand Down Expand Up @@ -290,6 +291,40 @@ bool MKLDNNExtractImagePatchesNode::isSupportedOperation(const std::shared_ptr<c
return true;
}

namespace {
struct ExtractImagePatchesKey {
VectorDims inDims;
VectorDims outDims;
VectorDims kSizes;
VectorDims strides;
VectorDims rates;
MKLDNNExtractImagePatchesNode::ExtImgPatcherPadType padType;
size_t prcSize;
size_t hash() const;
bool operator==(const ExtractImagePatchesKey& rhs) const;
};

size_t ExtractImagePatchesKey::hash() const {
using namespace dnnl::impl::primitive_hashing;
using namespace dnnl::impl;
size_t seed = 0;
seed = get_vector_hash(seed, inDims);
seed = get_vector_hash(seed, outDims);
seed = get_vector_hash(seed, kSizes);
seed = get_vector_hash(seed, strides);
seed = get_vector_hash(seed, rates);
seed = hash_combine(seed, padType);
seed = hash_combine(seed, prcSize);
return seed;
}

bool ExtractImagePatchesKey::operator==(const ExtractImagePatchesKey& rhs) const {
bool result = inDims == rhs.inDims && outDims == rhs.outDims && kSizes == rhs.kSizes && strides == rhs.strides &&
rates == rhs.rates && padType == rhs.padType && prcSize == rhs.prcSize;
return result;
}
} // namespace

MKLDNNExtractImagePatchesNode::MKLDNNExtractImagePatchesNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::engine& eng,
MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(op, eng, cache) {
std::string errorMessage;
Expand Down Expand Up @@ -340,11 +375,30 @@ void MKLDNNExtractImagePatchesNode::prepareParams() {
const auto& in_dims = getParentEdgeAt(0)->getMemory().getStaticDims();
const auto& out_dims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
const auto prcSize = getOriginalInputPrecisionAtPort(0).size();
if (mayiuse(x64::sse41)) {
execPtr = std::make_shared<ExtractImagePatchesJitExecutor>(in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize);
} else {
execPtr = std::make_shared<ExtractImagePatchesRefExecutor>(in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize);
}
ExtractImagePatchesKey key = {in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize};
const auto isJit = mayiuse(x64::sse41);
auto buildExecutor = [&isJit](const ExtractImagePatchesKey& key) -> executorPtr {
if (isJit) {
return std::make_shared<ExtractImagePatchesJitExecutor>(key.inDims,
key.outDims,
key.kSizes,
key.strides,
key.rates,
key.padType,
key.prcSize);
} else {
return std::make_shared<ExtractImagePatchesRefExecutor>(key.inDims,
key.outDims,
key.kSizes,
key.strides,
key.rates,
key.padType,
key.prcSize);
}
};
auto cache = getRuntimeCache();
auto result = cache->getOrCreate(key, buildExecutor);
execPtr = result.first;
}

void MKLDNNExtractImagePatchesNode::initSupportedPrimitiveDescriptors() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,13 @@ class MKLDNNExtractImagePatchesNode : public MKLDNNNode {
void prepareParams() override;

static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;

private:
enum class ExtImgPatcherPadType {
VALID,
SAME_LOWER,
SAME_UPPER
};

private:
std::vector<size_t> _ksizes;
std::vector<size_t> _strides;
std::vector<size_t> _rates;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ const std::vector<InputShape> inputShapes = {
// dynamic
{-1, -1, -1, -1},
// static
{{2, 3, 13, 37}, {6, 4, 14, 14}, {8, 12, 15, 16}}
{{2, 3, 13, 37}, {6, 4, 14, 14}, {8, 12, 15, 16}, {2, 3, 13, 37}}
},
InputShape{
// dynamic
{{5, 15}, {6, 17}, {10, 15}, {13, 16}},
// static
{{5, 17, 10, 15}, {15, 10, 12, 13}, {10, 10, 15, 16}}
{{5, 17, 10, 15}, {15, 10, 12, 13}, {10, 10, 15, 16}, {5, 17, 10, 15}}
},
};

Expand Down

0 comments on commit 986f0ea

Please sign in to comment.