diff --git a/src/plugins/intel_cpu/src/edge.cpp b/src/plugins/intel_cpu/src/edge.cpp index 291668e2ae4d18..090032f8e5ef28 100644 --- a/src/plugins/intel_cpu/src/edge.cpp +++ b/src/plugins/intel_cpu/src/edge.cpp @@ -51,19 +51,6 @@ bool Edge::isDropped() const { return not_in_parent && not_in_child; } -void Edge::drop() { - auto dropFrom = [] (const Edge* edge, std::vector &edges) { - edges.erase(std::remove_if(edges.begin(), edges.end(), - [&edge] (EdgeWeakPtr _edge) { - return _edge.lock().get() == edge; - }), - edges.end()); - }; - - dropFrom(this, getParent()->childEdges); - dropFrom(this, getChild()->parentEdges); -} - void Edge::collectConsumers(std::vector& result) const { if (!this->getChild()->getChildEdges().empty() && this->inPlace(LOOK_DOWN)) { if (auto peerChildSPD = this->getChild()->getSelectedPrimitiveDescriptor()) { diff --git a/src/plugins/intel_cpu/src/edge.h b/src/plugins/intel_cpu/src/edge.h index 31366b396830d9..4a188519d7227a 100644 --- a/src/plugins/intel_cpu/src/edge.h +++ b/src/plugins/intel_cpu/src/edge.h @@ -56,7 +56,6 @@ class Edge { void externalAllocate(WeightsSharing::Ptr weightsCache); void reuse(MemoryPtr ptr); void validate(); - void drop(); const std::shared_ptr getParent() const; const std::shared_ptr getChild() const; diff --git a/src/plugins/intel_cpu/src/graph.cpp b/src/plugins/intel_cpu/src/graph.cpp index 888e4094c5c7e3..d254f36a3efbac 100644 --- a/src/plugins/intel_cpu/src/graph.cpp +++ b/src/plugins/intel_cpu/src/graph.cpp @@ -260,8 +260,6 @@ void Graph::InitNodes() { OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, "Graph::InitNodes"); for (auto &node : graphNodes) { node->init(); - if (node->getConstantType() == Node::ConstantType::Unknown) - node->updateConstantType(); } } @@ -1435,7 +1433,9 @@ void Graph::CreateEdge(const NodePtr& parent, } void Graph::RemoveEdge(const EdgePtr& edge) { - edge->drop(); + edge->getParent()->removeChildEdge(edge); + edge->getChild()->removeParentEdge(edge); + graphEdges.erase(std::remove(graphEdges.begin(), graphEdges.end(), edge), graphEdges.end()); } @@ -1468,9 +1468,6 @@ void Graph::DropNode(const NodePtr &node) { const int outNum = c_edge->getOutputNum(); RemoveEdge(c_edge); CreateEdge(parent, child, inNum, outNum); - if (child->getConstantType() != Node::ConstantType::Unknown) { - child->updateConstantType(); - } } } } @@ -1574,7 +1571,8 @@ bool Graph::InsertNode(EdgePtr edge, NodePtr node, bool initNode) { " and ", edge->getChild()->getName(), "."); - edge->drop(); + edge->getParent()->removeChildEdge(edge); + edge->getChild()->removeParentEdge(edge); return InsertNode(edge->getParent(), edge->getChild(), node, iIndex, oIndex, initNode); } @@ -1582,7 +1580,6 @@ bool Graph::InsertNode(EdgePtr edge, NodePtr node, bool initNode) { bool Graph::InsertNode(NodePtr parent, NodePtr child, NodePtr node, int parentPort, int childPort, bool initNode) { CreateEdge(parent, node, parentPort, 0); CreateEdge(node, child, 0, childPort); - node->updateConstantType(); AddNode(node); if (initNode) { diff --git a/src/plugins/intel_cpu/src/graph_optimizer.cpp b/src/plugins/intel_cpu/src/graph_optimizer.cpp index c5a55909572933..cf40705d92ddce 100644 --- a/src/plugins/intel_cpu/src/graph_optimizer.cpp +++ b/src/plugins/intel_cpu/src/graph_optimizer.cpp @@ -1745,8 +1745,9 @@ void GraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(Graph &graph) } } - int peer_port = peerNode->getChildEdgeAt(childIdx)->getInputNum(); - peerNode->getChildEdgeAt(childIdx)->drop(); + auto peerEdge = peerNode->getChildEdgeAt(childIdx); + const int peer_port = peerEdge->getInputNum(); + graph.RemoveEdge(peerEdge); int childPort = 1; auto* mergedConvNode = dynamic_cast(mergedConv.get()); diff --git a/src/plugins/intel_cpu/src/node.cpp b/src/plugins/intel_cpu/src/node.cpp index ad189dd8994a1b..de686d0ff3185b 100644 --- a/src/plugins/intel_cpu/src/node.cpp +++ b/src/plugins/intel_cpu/src/node.cpp @@ -79,7 +79,7 @@ Node::Node(const std::shared_ptr& op, : selectedPrimitiveDescriptorIndex(-1), permanent(false), temporary(false), - constant(ConstantType::Unknown), + constant(ConstantType::NoConst), context(ctx), algorithm(Algorithm::Default), fusingPort(-1), @@ -185,7 +185,7 @@ Node::Node(const std::string& type, const std::string& name, const GraphContext: : selectedPrimitiveDescriptorIndex(-1), permanent(false), temporary(false), - constant(ConstantType::Unknown), + constant(ConstantType::NoConst), context(ctx), fusingPort(-1), engine(ctx->getEngine()), @@ -196,17 +196,13 @@ Node::Node(const std::string& type, const std::string& name, const GraphContext: // TODO [NM]: What about filling inDims and outDims? } -void Node::addEdge(const EdgeWeakPtr& edge) { - auto edgePtr = edge.lock(); - if (!edgePtr) - return; - auto parentPtr = edgePtr->getParent(); - auto childPtr = edgePtr->getChild(); - if (!parentPtr || !childPtr) - return; +void Node::addEdge(const EdgePtr& edge) { + auto parent = edge->getParent(); + auto child = edge->getChild(); + assert(parent && child); - parentPtr->addChildEdge(edge); - childPtr->addParentEdge(edge); + parent->addChildEdge(edge); + child->addParentEdge(edge); } void Node::remove() { @@ -214,7 +210,8 @@ void Node::remove() { for (auto& edge : edges) { auto edgePtr = edge.lock(); if (!edgePtr) continue; - edgePtr->drop(); + edgePtr->getParent()->removeChildEdge(edgePtr); + edgePtr->getChild()->removeParentEdge(edgePtr); } }; @@ -954,26 +951,26 @@ Node::ConstantType Node::getConstantType() const { } bool Node::isConstant() { - if (getConstantType() == ConstantType::Unknown) - updateConstantType(); return getConstantType() == ConstantType::Const; } void Node::updateConstantType() { - if (constant != ConstantType::StrictNoConst) { - bool isConst = true; - for (const auto& parentEdge : getParentEdges()) { - isConst &= parentEdge.lock()->getParent()->isConstant(); - } - constant = isConst ? ConstantType::Const : ConstantType::NoConst; + if (constant == ConstantType::StrictNoConst) + return; + + bool isConst = true; + for (const auto& parentEdge : getParentEdges()) { + isConst &= parentEdge.lock()->getParent()->isConstant(); } + const auto prevConstantType = constant; + constant = isConst ? ConstantType::Const : ConstantType::NoConst; + if (constant == prevConstantType) + return; // state has not changed, no reason to continue + for (const auto& childEdge : getChildEdges()) { const auto childNode = childEdge.lock()->getChild(); - const auto childConstType = childNode->getConstantType(); - if (!one_of(childConstType, ConstantType::Unknown, ConstantType::StrictNoConst, constant)) { - childNode->updateConstantType(); - } + childNode->updateConstantType(); } } diff --git a/src/plugins/intel_cpu/src/node.h b/src/plugins/intel_cpu/src/node.h index 20a7e725d01557..7601d09c4cd0d1 100644 --- a/src/plugins/intel_cpu/src/node.h +++ b/src/plugins/intel_cpu/src/node.h @@ -168,19 +168,29 @@ class Node { // @todo the method is used when graph is "preconstructed" before creation of the actual graph object // remove, as soon edges are added via Graph interface exclusively - static void addEdge(const EdgeWeakPtr& edge); + static void addEdge(const EdgePtr& edge); virtual void cleanup(); void remove(); - void addParentEdge(const EdgeWeakPtr& edge) { + void addParentEdge(const EdgePtr& edge) { parentEdges.push_back(edge); + updateConstantType(); } - void addChildEdge(const EdgeWeakPtr& edge) { + void addChildEdge(const EdgePtr& edge) { childEdges.push_back(edge); } + void removeParentEdge(const EdgePtr edge) { + removeEdge(edge, parentEdges); + updateConstantType(); + } + + void removeChildEdge(const EdgePtr edge) { + removeEdge(edge, childEdges); + } + const std::vector &getParentEdges() const noexcept { return parentEdges; } @@ -214,7 +224,6 @@ class Node { } enum class ConstantType { - Unknown, // Unknown ConstantType is used before the constancy determination procedure run Const, // Node is placed in a constant subgraph NoConst, // Node is placed in a non-constant subgraph StrictNoConst, // Node produces non-constant subgraph: this type can't be changed and it does not depend on the parent nodes' ConstantType. @@ -614,7 +623,7 @@ class Node { NoInPlace }; mutable InPlaceType inplace = InPlaceType::Unknown; - ConstantType constant = ConstantType::Unknown; + ConstantType constant = ConstantType::NoConst; std::vector internalBlobs; std::vector internalBlobMemory; std::vector supportedPrimitiveDescriptors; @@ -708,6 +717,16 @@ class Node { std::shared_ptr shapeInference; private: + static void removeEdge(const EdgePtr edge, std::vector &edges) { + edges.erase(std::remove_if(edges.begin(), edges.end(), + [&edge] (EdgeWeakPtr _edge) { + return _edge.lock() == edge; + }), + edges.end()); + } + + bool isEdgesEmpty(const std::vector& edges) const; + std::vector parentEdges; std::vector childEdges; @@ -730,8 +749,6 @@ class Node { MemoryPtr scratchpadMem; - bool isEdgesEmpty(const std::vector& edges) const; - // Hold output scales std::vector DQScales; // we cannot rely on per-NUMA weightCache for caching weights because: diff --git a/src/plugins/intel_cpu/src/nodes/input.cpp b/src/plugins/intel_cpu/src/nodes/input.cpp index a6ebe7b27790b9..518350e3d47627 100644 --- a/src/plugins/intel_cpu/src/nodes/input.cpp +++ b/src/plugins/intel_cpu/src/nodes/input.cpp @@ -218,22 +218,21 @@ jit_has_subnormals_base::fn_t jit_has_subnormals_function() { Input::Input(const std::shared_ptr& op, const GraphContext::CPtr context) : Node(op, context, PassThroughShapeInferFactory()) { if (!one_of(op->get_type_info(), - op::v0::Parameter::get_type_info_static(), - op::v0::Constant::get_type_info_static(), - op::v0::Result::get_type_info_static(), - op::v3::ReadValue::get_type_info_static(), - op::v6::ReadValue::get_type_info_static())) + op::v0::Parameter::get_type_info_static(), + op::v0::Constant::get_type_info_static(), + op::v0::Result::get_type_info_static(), + op::v3::ReadValue::get_type_info_static(), + op::v6::ReadValue::get_type_info_static())) OPENVINO_THROW_NOT_IMPLEMENTED("CPU Input node doesn't support ngraph operation ", op->get_type_name(), " with name ", op->get_friendly_name()); - - constant = ConstantType::NoConst; - constOp = ov::as_type_ptr(op); if (constOp) { constant = ConstantType::Const; cloneBlobIfRequired(); + } else { + constant = ConstantType::StrictNoConst; } }