Skip to content

Commit

Permalink
[CPU] Update Node constant type when adding or removing parent edges (o…
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorDuplensky authored Jan 8, 2024
1 parent c0d564a commit 1935bba
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 64 deletions.
13 changes: 0 additions & 13 deletions src/plugins/intel_cpu/src/edge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,6 @@ bool Edge::isDropped() const {
return not_in_parent && not_in_child;
}

void Edge::drop() {
auto dropFrom = [] (const Edge* edge, std::vector<EdgeWeakPtr> &edges) {
edges.erase(std::remove_if(edges.begin(), edges.end(),
[&edge] (EdgeWeakPtr _edge) {
return _edge.lock().get() == edge;
}),
edges.end());
};

dropFrom(this, getParent()->childEdges);
dropFrom(this, getChild()->parentEdges);
}

void Edge::collectConsumers(std::vector<NodePtr>& result) const {
if (!this->getChild()->getChildEdges().empty() && this->inPlace(LOOK_DOWN)) {
if (auto peerChildSPD = this->getChild()->getSelectedPrimitiveDescriptor()) {
Expand Down
1 change: 0 additions & 1 deletion src/plugins/intel_cpu/src/edge.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ class Edge {
void externalAllocate(WeightsSharing::Ptr weightsCache);
void reuse(MemoryPtr ptr);
void validate();
void drop();

const std::shared_ptr<Node> getParent() const;
const std::shared_ptr<Node> getChild() const;
Expand Down
13 changes: 5 additions & 8 deletions src/plugins/intel_cpu/src/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,6 @@ void Graph::InitNodes() {
OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, "Graph::InitNodes");
for (auto &node : graphNodes) {
node->init();
if (node->getConstantType() == Node::ConstantType::Unknown)
node->updateConstantType();
}
}

Expand Down Expand Up @@ -1435,7 +1433,9 @@ void Graph::CreateEdge(const NodePtr& parent,
}

void Graph::RemoveEdge(const EdgePtr& edge) {
edge->drop();
edge->getParent()->removeChildEdge(edge);
edge->getChild()->removeParentEdge(edge);

graphEdges.erase(std::remove(graphEdges.begin(), graphEdges.end(), edge), graphEdges.end());
}

Expand Down Expand Up @@ -1468,9 +1468,6 @@ void Graph::DropNode(const NodePtr &node) {
const int outNum = c_edge->getOutputNum();
RemoveEdge(c_edge);
CreateEdge(parent, child, inNum, outNum);
if (child->getConstantType() != Node::ConstantType::Unknown) {
child->updateConstantType();
}
}
}
}
Expand Down Expand Up @@ -1574,15 +1571,15 @@ bool Graph::InsertNode(EdgePtr edge, NodePtr node, bool initNode) {
" and ",
edge->getChild()->getName(),
".");
edge->drop();
edge->getParent()->removeChildEdge(edge);
edge->getChild()->removeParentEdge(edge);

return InsertNode(edge->getParent(), edge->getChild(), node, iIndex, oIndex, initNode);
}

bool Graph::InsertNode(NodePtr parent, NodePtr child, NodePtr node, int parentPort, int childPort, bool initNode) {
CreateEdge(parent, node, parentPort, 0);
CreateEdge(node, child, 0, childPort);
node->updateConstantType();
AddNode(node);

if (initNode) {
Expand Down
5 changes: 3 additions & 2 deletions src/plugins/intel_cpu/src/graph_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1745,8 +1745,9 @@ void GraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(Graph &graph)
}
}

int peer_port = peerNode->getChildEdgeAt(childIdx)->getInputNum();
peerNode->getChildEdgeAt(childIdx)->drop();
auto peerEdge = peerNode->getChildEdgeAt(childIdx);
const int peer_port = peerEdge->getInputNum();
graph.RemoveEdge(peerEdge);

int childPort = 1;
auto* mergedConvNode = dynamic_cast<Convolution*>(mergedConv.get());
Expand Down
47 changes: 22 additions & 25 deletions src/plugins/intel_cpu/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ Node::Node(const std::shared_ptr<ov::Node>& op,
: selectedPrimitiveDescriptorIndex(-1),
permanent(false),
temporary(false),
constant(ConstantType::Unknown),
constant(ConstantType::NoConst),
context(ctx),
algorithm(Algorithm::Default),
fusingPort(-1),
Expand Down Expand Up @@ -185,7 +185,7 @@ Node::Node(const std::string& type, const std::string& name, const GraphContext:
: selectedPrimitiveDescriptorIndex(-1),
permanent(false),
temporary(false),
constant(ConstantType::Unknown),
constant(ConstantType::NoConst),
context(ctx),
fusingPort(-1),
engine(ctx->getEngine()),
Expand All @@ -196,25 +196,22 @@ Node::Node(const std::string& type, const std::string& name, const GraphContext:
// TODO [NM]: What about filling inDims and outDims?
}

void Node::addEdge(const EdgeWeakPtr& edge) {
auto edgePtr = edge.lock();
if (!edgePtr)
return;
auto parentPtr = edgePtr->getParent();
auto childPtr = edgePtr->getChild();
if (!parentPtr || !childPtr)
return;
void Node::addEdge(const EdgePtr& edge) {
auto parent = edge->getParent();
auto child = edge->getChild();
assert(parent && child);

parentPtr->addChildEdge(edge);
childPtr->addParentEdge(edge);
parent->addChildEdge(edge);
child->addParentEdge(edge);
}

void Node::remove() {
auto drop = [](std::vector<EdgeWeakPtr> edges){
for (auto& edge : edges) {
auto edgePtr = edge.lock();
if (!edgePtr) continue;
edgePtr->drop();
edgePtr->getParent()->removeChildEdge(edgePtr);
edgePtr->getChild()->removeParentEdge(edgePtr);
}
};

Expand Down Expand Up @@ -954,26 +951,26 @@ Node::ConstantType Node::getConstantType() const {
}

bool Node::isConstant() {
if (getConstantType() == ConstantType::Unknown)
updateConstantType();
return getConstantType() == ConstantType::Const;
}

void Node::updateConstantType() {
if (constant != ConstantType::StrictNoConst) {
bool isConst = true;
for (const auto& parentEdge : getParentEdges()) {
isConst &= parentEdge.lock()->getParent()->isConstant();
}
constant = isConst ? ConstantType::Const : ConstantType::NoConst;
if (constant == ConstantType::StrictNoConst)
return;

bool isConst = true;
for (const auto& parentEdge : getParentEdges()) {
isConst &= parentEdge.lock()->getParent()->isConstant();
}

const auto prevConstantType = constant;
constant = isConst ? ConstantType::Const : ConstantType::NoConst;
if (constant == prevConstantType)
return; // state has not changed, no reason to continue

for (const auto& childEdge : getChildEdges()) {
const auto childNode = childEdge.lock()->getChild();
const auto childConstType = childNode->getConstantType();
if (!one_of(childConstType, ConstantType::Unknown, ConstantType::StrictNoConst, constant)) {
childNode->updateConstantType();
}
childNode->updateConstantType();
}
}

Expand Down
31 changes: 24 additions & 7 deletions src/plugins/intel_cpu/src/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,29 @@ class Node {

// @todo the method is used when graph is "preconstructed" before creation of the actual graph object
// remove, as soon edges are added via Graph interface exclusively
static void addEdge(const EdgeWeakPtr& edge);
static void addEdge(const EdgePtr& edge);

virtual void cleanup();
void remove();

void addParentEdge(const EdgeWeakPtr& edge) {
void addParentEdge(const EdgePtr& edge) {
parentEdges.push_back(edge);
updateConstantType();
}

void addChildEdge(const EdgeWeakPtr& edge) {
void addChildEdge(const EdgePtr& edge) {
childEdges.push_back(edge);
}

void removeParentEdge(const EdgePtr edge) {
removeEdge(edge, parentEdges);
updateConstantType();
}

void removeChildEdge(const EdgePtr edge) {
removeEdge(edge, childEdges);
}

const std::vector<EdgeWeakPtr> &getParentEdges() const noexcept {
return parentEdges;
}
Expand Down Expand Up @@ -214,7 +224,6 @@ class Node {
}

enum class ConstantType {
Unknown, // Unknown ConstantType is used before the constancy determination procedure run
Const, // Node is placed in a constant subgraph
NoConst, // Node is placed in a non-constant subgraph
StrictNoConst, // Node produces non-constant subgraph: this type can't be changed and it does not depend on the parent nodes' ConstantType.
Expand Down Expand Up @@ -614,7 +623,7 @@ class Node {
NoInPlace
};
mutable InPlaceType inplace = InPlaceType::Unknown;
ConstantType constant = ConstantType::Unknown;
ConstantType constant = ConstantType::NoConst;
std::vector<MemoryPtr> internalBlobs;
std::vector<MemoryPtr> internalBlobMemory;
std::vector<NodeDesc> supportedPrimitiveDescriptors;
Expand Down Expand Up @@ -708,6 +717,16 @@ class Node {
std::shared_ptr<IShapeInfer> shapeInference;

private:
static void removeEdge(const EdgePtr edge, std::vector<EdgeWeakPtr> &edges) {
edges.erase(std::remove_if(edges.begin(), edges.end(),
[&edge] (EdgeWeakPtr _edge) {
return _edge.lock() == edge;
}),
edges.end());
}

bool isEdgesEmpty(const std::vector<EdgeWeakPtr>& edges) const;

std::vector<EdgeWeakPtr> parentEdges;
std::vector<EdgeWeakPtr> childEdges;

Expand All @@ -730,8 +749,6 @@ class Node {

MemoryPtr scratchpadMem;

bool isEdgesEmpty(const std::vector<EdgeWeakPtr>& edges) const;

// Hold output scales
std::vector<float> DQScales;
// we cannot rely on per-NUMA weightCache for caching weights because:
Expand Down
15 changes: 7 additions & 8 deletions src/plugins/intel_cpu/src/nodes/input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,22 +218,21 @@ jit_has_subnormals_base::fn_t jit_has_subnormals_function() {
Input::Input(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
: Node(op, context, PassThroughShapeInferFactory()) {
if (!one_of(op->get_type_info(),
op::v0::Parameter::get_type_info_static(),
op::v0::Constant::get_type_info_static(),
op::v0::Result::get_type_info_static(),
op::v3::ReadValue::get_type_info_static(),
op::v6::ReadValue::get_type_info_static()))
op::v0::Parameter::get_type_info_static(),
op::v0::Constant::get_type_info_static(),
op::v0::Result::get_type_info_static(),
op::v3::ReadValue::get_type_info_static(),
op::v6::ReadValue::get_type_info_static()))
OPENVINO_THROW_NOT_IMPLEMENTED("CPU Input node doesn't support ngraph operation ",
op->get_type_name(),
" with name ",
op->get_friendly_name());

constant = ConstantType::NoConst;

constOp = ov::as_type_ptr<op::v0::Constant>(op);
if (constOp) {
constant = ConstantType::Const;
cloneBlobIfRequired();
} else {
constant = ConstantType::StrictNoConst;
}
}

Expand Down

0 comments on commit 1935bba

Please sign in to comment.