Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] Use Dnnl executor to avoid extra dnnl primitive desc query #16372

Merged
merged 5 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/plugins/intel_cpu/src/memory_desc/dnnl_memory_desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,18 @@ MemoryDescPtr DnnlMemoryDesc::cloneWithNewPrecision(const InferenceEngine::Preci
}

bool DnnlMemoryDesc::isCompatible(const MemoryDesc &rhs) const {
if (MemoryDescType::Dnnl == rhs.getType()) {
return this->desc == rhs.as<DnnlMemoryDesc>()->desc;
if (MemoryDescType::Dnnl & rhs.getType()) {
auto* dnnMemDesc = rhs.as<DnnlMemoryDesc>();
return isCompatible(*dnnMemDesc);
} else {
return false;
}
}

bool DnnlMemoryDesc::isCompatible(const DnnlMemoryDesc& rhs) const {
return this->desc == rhs.desc;
}

std::string DnnlMemoryDesc::serializeFormat() const {
dnnl::impl::memory_desc_wrapper wrapped(desc.get());
if (wrapped.is_wino_desc()) {
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/memory_desc/dnnl_memory_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class DnnlMemoryDesc : public virtual MemoryDesc {
MemoryDescPtr cloneWithNewPrecision(const InferenceEngine::Precision prec) const override;

bool isCompatible(const MemoryDesc& rhs) const override;
bool isCompatible(const DnnlMemoryDesc& rhs) const;

bool hasLayoutType(LayoutType layoutType) const override { return false; }

Expand Down
6 changes: 0 additions & 6 deletions src/plugins/intel_cpu/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,12 +550,6 @@ std::vector<memory::format_tag> Node::getAvailableFormatsForDims(const Shape &di
return {memory::format_tag::any};
}

void Node::execute(dnnl::stream strm) {
if (prim) {
prim.execute(strm, primArgs);
}
}

void Node::updateShapes() {
IE_ASSERT(isDynamicNode()) << "Node::updateShapes() is called to a static shape node of type: " << getTypeStr() << " with name: " << getName();
if (needShapeInfer()) {
Expand Down
10 changes: 5 additions & 5 deletions src/plugins/intel_cpu/src/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ class Node {

void resolveInPlaceEdges();

virtual void execute(dnnl::stream strm);
virtual void execute(dnnl::stream strm) = 0;
void updateShapes();
void updateDynamicParams();
void executeDynamic(dnnl::stream strm);
Expand Down Expand Up @@ -578,7 +578,6 @@ class Node {
std::vector<NodeDesc> supportedPrimitiveDescriptors;
std::unordered_map<int, dnnl::memory> primArgs;
std::unordered_map<int, MemoryPtr> postOpsArgs;
dnnl::primitive prim;
std::vector<dnnl::primitive_desc> descs;

const GraphContext::CPtr context;
Expand Down Expand Up @@ -649,9 +648,10 @@ class Node {
IE_THROW(NotImplemented) << "[DS] prapareParams not implemented for node with type " << NameFromType(getType());
}

MemoryPtr getScratchPadMem(const const_dnnl_primitive_desc_t& pd) {
auto scratchpadMemoryDesc = DnnlExtensionUtils::query_md(pd, dnnl::query::scratchpad_md);
scratchpadMem = context->getScratchPad()->createScratchPadMem(scratchpadMemoryDesc);
MemoryPtr getScratchPadMem(const DnnlMemoryDescPtr& desc) {
if (!scratchpadMem || !scratchpadMem->getDesc().isCompatible(*desc)) {
scratchpadMem = context->getScratchPad()->createScratchPadMem(desc);
}
return scratchpadMem;
}

Expand Down
39 changes: 17 additions & 22 deletions src/plugins/intel_cpu/src/nodes/common/dnnl_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ using namespace dnnl;
namespace ov {
namespace intel_cpu {

DnnlExecutor::DnnlExecutor(const dnnl::primitive_desc& pd) {
execPrim = dnnl::primitive(pd);
src_md = DnnlExtensionUtils::makeDescriptor(pd.src_desc());
dst_md = DnnlExtensionUtils::makeDescriptor(pd.dst_desc());
wghts_md = DnnlExtensionUtils::makeDescriptor(pd.weights_desc());
scrch_md = DnnlExtensionUtils::makeDescriptor(pd.scratchpad_desc());
}

DnnlExecutor::IntermReorder::IntermReorder(const dnnl::memory::desc& descSrc,
const dnnl::memory::desc& descDst,
const dnnl::engine& engine) : m_descSrc(descSrc), m_descDst(descDst) {
Expand All @@ -20,7 +28,15 @@ void DnnlExecutor::IntermReorder::exec(dnnl::memory& memSrc, dnnl::memory& memDs
m_reorder.execute(strm, memSrc, memDst);
}

void DnnlExecutor::exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm) {
void DnnlExecutor::exec(const std::unordered_map<int, dnnl::memory>& primArgs, dnnl::stream strm) {
if (inputReorders.empty() && outputReorders.empty()) {
execPrim.execute(strm, primArgs);
} else {
reorder_exec(primArgs, strm);
}
}

void DnnlExecutor::reorder_exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm) {
for (auto &inReorder : inputReorders) {
if (primArgs.count(inReorder.first)) {
dnnl::memory memDst(inReorder.second.getDstDesc(), strm.get_engine());
Expand Down Expand Up @@ -58,27 +74,6 @@ const_dnnl_primitive_desc_t DnnlExecutor::getPrimitiveDesc() const {
return execPrim.get_primitive_desc();
}

dnnl::memory::desc DnnlExecutor::getSrcDesc() const {
auto pd = getPrimitiveDesc();
auto md = DnnlExtensionUtils::query_md(pd, dnnl::query::src_md);

return md->getDnnlDesc();
}

dnnl::memory::desc DnnlExecutor::getWeightDesc() const {
auto pd = getPrimitiveDesc();
auto md = DnnlExtensionUtils::query_md(pd, dnnl::query::weights_md);

return md->getDnnlDesc();
}

dnnl::memory::desc DnnlExecutor::getDstDesc() const {
auto pd = getPrimitiveDesc();
auto md = DnnlExtensionUtils::query_md(pd, dnnl::query::dst_md);

return md->getDnnlDesc();
}

impl_desc_type DnnlExecutor::getImplementationType() const {
auto pd = getPrimitiveDesc();
return parse_impl_name(DnnlExtensionUtils::query_impl_info_str(pd));
Expand Down
40 changes: 35 additions & 5 deletions src/plugins/intel_cpu/src/nodes/common/dnnl_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,52 @@ class DnnlExecutor {
};

public:
void exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm);
explicit DnnlExecutor(const dnnl::primitive_desc& pd);
void exec(const std::unordered_map<int, dnnl::memory>& primArgs, dnnl::stream strm);
bool needReordering() const;
virtual ~DnnlExecutor() = default;
dnnl::primitive getExecPrim() const;
const_dnnl_primitive_desc_t getPrimitiveDesc() const;
dnnl::memory::desc getSrcDesc() const;
dnnl::memory::desc getWeightDesc() const;
dnnl::memory::desc getDstDesc() const;
impl_desc_type getImplementationType() const;

DnnlMemoryDescPtr getSrcDesc() const {
return src_md;
}
DnnlMemoryDescPtr getWeightDesc() const {
return wghts_md;
}
DnnlMemoryDescPtr getDstDesc() const {
return dst_md;
}
DnnlMemoryDescPtr getScratchPadDesc() const {
return scrch_md;
}

const dnnl::memory::desc& getDnnlSrcDesc() const {
return src_md->getDnnlDesc();
}
const dnnl::memory::desc& getDnnlWeightDesc() const {
return wghts_md->getDnnlDesc();
}
const dnnl::memory::desc& getDnnlDstDesc() const {
return dst_md->getDnnlDesc();
}
const dnnl::memory::desc& getDnnlScratchPadDesc() const {
return scrch_md->getDnnlDesc();
}

protected:
void reorder_exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm);

protected:
DnnlExecutor() = default;
dnnl::primitive execPrim;
// key is the port number for the primitive that needs memory reordering
std::unordered_map<int, IntermReorder> inputReorders;
std::unordered_map<int, IntermReorder> outputReorders;
DnnlMemoryDescPtr src_md;
DnnlMemoryDescPtr wghts_md;
DnnlMemoryDescPtr dst_md;
DnnlMemoryDescPtr scrch_md;
};

} // namespace intel_cpu
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/nodes/concat.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class Concat : public Node {
InferenceEngine::Precision outputPrecision = InferenceEngine::Precision::FP32;
bool canExecRef = false;
static constexpr size_t MAX_RANK_REF = 6;
dnnl::primitive prim;
};

} // namespace node
Expand Down
19 changes: 8 additions & 11 deletions src/plugins/intel_cpu/src/nodes/conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1495,8 +1495,7 @@ void Convolution::prepareParams() {

Node::appendPostOpArgs(*pAttrLocal, primArgs, convPostOpsArgs[preferLegacyPostOps]);

auto pd = execPtr->getPrimitiveDesc();
auto scratchpadMem = getScratchPadMem(pd);
auto scratchpadMem = getScratchPadMem(execPtr->getScratchPadDesc());
primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();

#ifdef CPU_DEBUG_CAPS
Expand All @@ -1513,19 +1512,17 @@ Convolution::ConvolutionExecutor::ConvolutionExecutor(const dnnl::convolution_fo
const dnnl::memory::desc& inMemDesc,
const dnnl::memory::desc& weightMemDesc,
const dnnl::memory::desc& outMemDesc,
const dnnl::engine& engine) {
execPrim = dnnl::convolution_forward(pd);

if (inMemDesc != pd.src_desc()) {
inputReorders.insert({DNNL_ARG_SRC, IntermReorder(inMemDesc, pd.src_desc(), engine)});
const dnnl::engine& engine) : DnnlExecutor(pd) {
if (inMemDesc != getDnnlSrcDesc()) {
inputReorders.insert({DNNL_ARG_SRC, IntermReorder(inMemDesc, getDnnlSrcDesc(), engine)});
}

if (weightMemDesc != pd.weights_desc()) {
inputReorders.insert({DNNL_ARG_WEIGHTS, IntermReorder(weightMemDesc, pd.weights_desc(), engine)});
if (weightMemDesc != getDnnlWeightDesc()) {
inputReorders.insert({DNNL_ARG_WEIGHTS, IntermReorder(weightMemDesc, getDnnlWeightDesc(), engine)});
}

if (outMemDesc != pd.dst_desc()) {
outputReorders.insert({DNNL_ARG_DST, IntermReorder(pd.dst_desc(), outMemDesc, engine)});
if (outMemDesc != getDnnlDstDesc()) {
outputReorders.insert({DNNL_ARG_DST, IntermReorder(getDnnlDstDesc(), outMemDesc, engine)});
}
}

Expand Down
23 changes: 9 additions & 14 deletions src/plugins/intel_cpu/src/nodes/deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -991,8 +991,7 @@ void Deconvolution::prepareParams() {
}
Node::appendPostOpArgs(*pAttrLocal, primArgs, postOpsArgs);

auto pd = execPtr->getPrimitiveDesc();
auto scratchpadMem = getScratchPadMem(pd);
auto scratchpadMem = getScratchPadMem(execPtr->getScratchPadDesc());
primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();
#ifdef CPU_DEBUG_CAPS
if (result.second == CacheEntryBase::LookUpStatus::Miss) {
Expand Down Expand Up @@ -1094,9 +1093,7 @@ Deconvolution::DeconvExecutorDefault::DeconvExecutorDefault(const dnnl::convolut
const dnnl::memory::desc& inMemDesc,
const dnnl::memory::desc& weightMemDesc,
const dnnl::memory::desc& outMemDesc,
const dnnl::engine& engine) {
execPrim = dnnl::convolution_backward_data(pd);

const dnnl::engine& engine) : DnnlExecutor(pd) {
if (inMemDesc != pd.diff_dst_desc()) {
inputReorders.insert({DNNL_ARG_DIFF_DST, IntermReorder(inMemDesc, pd.diff_dst_desc(), engine)});
}
Expand All @@ -1114,19 +1111,17 @@ Deconvolution::DeconvExecutorInt8::DeconvExecutorInt8(const dnnl::deconvolution_
const dnnl::memory::desc& inMemDesc,
const dnnl::memory::desc& weightMemDesc,
const dnnl::memory::desc& outMemDesc,
const dnnl::engine& engine) {
execPrim = dnnl::deconvolution_forward(pd);

if (inMemDesc != pd.src_desc()) {
inputReorders.insert({DNNL_ARG_SRC, IntermReorder(inMemDesc, pd.src_desc(), engine)});
const dnnl::engine& engine) : DnnlExecutor(pd) {
if (inMemDesc != getDnnlSrcDesc()) {
inputReorders.insert({DNNL_ARG_SRC, IntermReorder(inMemDesc, getDnnlSrcDesc(), engine)});
}

if (weightMemDesc != pd.weights_desc()) {
inputReorders.insert({DNNL_ARG_WEIGHTS, IntermReorder(weightMemDesc, pd.weights_desc(), engine)});
if (weightMemDesc != getDnnlWeightDesc()) {
inputReorders.insert({DNNL_ARG_WEIGHTS, IntermReorder(weightMemDesc, getDnnlWeightDesc(), engine)});
}

if (outMemDesc != pd.dst_desc()) {
outputReorders.insert({DNNL_ARG_DST, IntermReorder(pd.dst_desc(), outMemDesc, engine)});
if (outMemDesc != getDnnlDstDesc()) {
outputReorders.insert({DNNL_ARG_DST, IntermReorder(getDnnlDstDesc(), outMemDesc, engine)});
}
}

Expand Down
49 changes: 17 additions & 32 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ void FullyConnected::prepareParams() {
implementationTypeIP,
useConv1x1};

auto engine = getEngine();
auto& engine = getEngine();

auto builder = [&engine](const FCKey& key) -> executorPtr {
executorPtr execPtr = nullptr;
Expand All @@ -333,7 +333,7 @@ void FullyConnected::prepareParams() {
}

if (prim_desc) {
execPtr = std::make_shared<ExecutorConv1x1>(prim_desc);
execPtr = std::make_shared<DnnlExecutor>(prim_desc);
}
}
// fallback
Expand Down Expand Up @@ -388,7 +388,7 @@ void FullyConnected::prepareParams() {
}
}

execPtr = std::make_shared<ExecutorInnerProduct>(prim_desc);
execPtr = std::make_shared<DnnlExecutor>(prim_desc);
}
return execPtr;
};
Expand All @@ -404,26 +404,20 @@ void FullyConnected::prepareParams() {
execPtr = result.first;

if (execPtr) {
// no executor yet or shapes changed
if (!prevExecPtr || prevExecPtr->getSrcDesc() != execPtr->getSrcDesc()) {
auto oldMem = srcMemPtr->GetPrimitive();
// fast path: wanted is same with parent node output, typical is static shape with inner product
if (execPtr->getSrcDesc() == inDesc->getDnnlDesc()) {
primArgs[DNNL_ARG_SRC] = std::move(oldMem);
} else {
primArgs[DNNL_ARG_SRC] = dnnl::memory(execPtr->getSrcDesc(), oldMem.get_engine(), oldMem.get_data_handle());
}
if (execPtr->getSrcDesc()->isCompatible(*inDesc)) {
primArgs[DNNL_ARG_SRC] = srcMemPtr->GetPrimitive();
} else {
primArgs[DNNL_ARG_SRC] = dnnl::memory(execPtr->getDnnlSrcDesc(), engine, srcMemPtr->GetData());
}
if (!prevExecPtr || prevExecPtr->getDstDesc() != execPtr->getDstDesc()) {
auto oldMem = dstMemPtr->GetPrimitive();
if (execPtr->getDstDesc() == outDesc->getDnnlDesc()) {
primArgs[DNNL_ARG_DST] = std::move(oldMem);
} else {
primArgs[DNNL_ARG_DST] = dnnl::memory(execPtr->getDstDesc(), oldMem.get_engine(), oldMem.get_data_handle());
}

if (execPtr->getDstDesc()->isCompatible(*outDesc)) {
primArgs[DNNL_ARG_DST] = dstMemPtr->GetPrimitive();
} else {
primArgs[DNNL_ARG_DST] = dnnl::memory(execPtr->getDnnlDstDesc(), engine, dstMemPtr->GetData());
}
if (!prevExecPtr || prevExecPtr->getWeightDesc() != execPtr->getWeightDesc()) {
primArgs[DNNL_ARG_WEIGHTS] = prepareWeightMemory(DnnlExtensionUtils::makeDescriptor(execPtr->getWeightDesc()))->GetPrimitive();

if (!prevExecPtr || !execPtr->getWeightDesc()->isCompatible(*(prevExecPtr->getWeightDesc()))) {
primArgs[DNNL_ARG_WEIGHTS] = prepareWeightMemory(execPtr->getWeightDesc())->GetPrimitive();
}
// changed shapes may also cause the kernel type changed
selected_pd->setImplementationType(execPtr->getImplementationType());
Expand All @@ -438,9 +432,8 @@ void FullyConnected::prepareParams() {
primArgs[DNNL_ARG_BIAS] = biasMemPtr->GetPrimitive();
}

auto pd = execPtr->getPrimitiveDesc();
auto scratchpadMem = getScratchPadMem(pd);
primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();
auto schratchpadMem = getScratchPadMem(execPtr->getScratchPadDesc());
primArgs[DNNL_ARG_SCRATCHPAD] = schratchpadMem->GetPrimitive();
#ifdef CPU_DEBUG_CAPS
if (result.second == CacheEntryBase::LookUpStatus::Miss) {
DEBUG_LOG("verbose##", getName(), "##", pd->info(), "\n");
Expand Down Expand Up @@ -919,14 +912,6 @@ bool FullyConnected::canBeExecutedInConv1x1() const {
return retVal;
}

FullyConnected::ExecutorInnerProduct::ExecutorInnerProduct(const dnnl::inner_product_forward::primitive_desc& pd) {
execPrim = dnnl::inner_product_forward(pd);
}

FullyConnected::ExecutorConv1x1::ExecutorConv1x1(const dnnl::convolution_forward::primitive_desc& pd) {
execPrim = dnnl::convolution_forward(pd);
}

MemoryPtr FullyConnected::prepareWeightMemory(DnnlMemoryDescPtr weightDesc) {
if (!getParentEdgeAt(1)->getParent()->isConstant())
IE_THROW() << "Weight input is not const for node " << getName() << ".";
Expand Down
10 changes: 0 additions & 10 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,6 @@ class FullyConnected : public Node {
std::unordered_map<std::string, MemoryPtr> privateWeightCache;
dnnl::primitive_attr attr;

class ExecutorInnerProduct : public DnnlExecutor {
public:
ExecutorInnerProduct(const dnnl::inner_product_forward::primitive_desc& pd);
};

class ExecutorConv1x1 : public DnnlExecutor {
public:
ExecutorConv1x1(const dnnl::convolution_forward::primitive_desc& pd);
};

static dnnl::convolution_forward::primitive_desc
createDescriptorInternalForConv(DnnlMemoryDescCPtr inputDescPtr,
DnnlMemoryDescCPtr weightDescPtr,
Expand Down
Loading