Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] Drop redundant MemoryOutput nodes #27189

Merged
merged 17 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions src/plugins/intel_cpu/src/graph_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) {
MatchSdpaKvCache(graph);
graph.RemoveDroppedNodes();

OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "DropRedundantMemoryOutput");
DropRedundantMemoryOutput(graph);
graph.RemoveDroppedNodes();

OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "RemoveDroppedEdges");
graph.RemoveDroppedEdges();
}
Expand Down Expand Up @@ -3186,5 +3190,117 @@ void GraphOptimizer::MatchSdpaKvCache(Graph &graph) {
}
}

void GraphOptimizer::DropRedundantMemoryOutput(Graph &graph) {
// When we have a MemoryInput->MemoryOutput pair, that means that the state is immediately populated with the init
// subgraph values when the init subgraph exists. In all the other cases the state is simply a read only object.
// We can optimize such a case removing the MemoryOutput node and transferring the state values update
// responsibility to a special type of the MemoryInput node - MemoryInputSingle
auto& graphNodes = graph.GetNodes();

auto isSuitableMemInput = [](const NodePtr& node) -> bool {
if (Type::MemoryInput != node->getType()) {
return false;
}

CPU_GRAPH_OPTIMIZER_SCOPE(DropRedundantMemoryOutput_isSuitableMemInput);

auto memInputBase = std::dynamic_pointer_cast<MemoryNode>(node);
OPENVINO_ASSERT(memInputBase,
"Unexpectedly wrong dynamic type of node: ",
node->getName(),
" of type: ",
node->getTypeStr());

auto id = memInputBase->getId();

NodePtr MemoryOutput = nullptr;
auto&& childEdges = node->getChildEdgesAtPort(0);
for (auto&& item : childEdges) {
auto childNode = item->getChild();

if (Type::MemoryOutput == childNode->getType()) {
auto memOutputBase = std::dynamic_pointer_cast<MemoryNode>(childNode);
OPENVINO_ASSERT(memInputBase,
"Unexpectedly wrong dynamic type of node: ",
node->getName(),
" of type: ",
node->getTypeStr());

if (memOutputBase->getId() != id) {
return false; // an Assign node from different Variable is attached
}

if (MemoryOutput && MemoryOutput != childNode) {
//only one child MemoryOutput is expected
return false;
}
MemoryOutput = childNode;
}
}
return nullptr != MemoryOutput;
};

for (size_t i = 0; i < graphNodes.size(); i++) {
auto node = graphNodes[i];
if (!isSuitableMemInput(node)) {
continue;
}

CPU_GRAPH_OPTIMIZER_SCOPE(DropRedundantMemoryOutput_Node);

auto memInputNode = std::dynamic_pointer_cast<node::MemoryInputBase>(node);
OPENVINO_ASSERT(memInputNode, "MemoryInput node ", node->getName(), " has unexpected dynamic type");

ov::optional<Shape> inputShape;
ov::optional<ov::element::Type> inputPrc;

if (!node->getParentEdges().empty()) {
inputShape = ov::optional<Shape>(node->getInputShapeAtPort(0));
inputPrc = ov::optional<ov::element::Type>(node->getOriginalInputPrecisionAtPort(0));
}

//search for the MemoryOutputNode
NodePtr memoryOutputNode;
for (auto&& edge : node->getChildEdgesAtPort(0)) {
auto child = edge->getChild();
if (Type::MemoryOutput == child->getType()) {
memoryOutputNode = child;
break;
}
}
OPENVINO_ASSERT(memoryOutputNode, "Corresponding MemoryOutput has not been found");

graph.RemoveEdge(memoryOutputNode->getParentEdgeAt(0));
// there are no output edges from MemoryOutput nodes

// now replace the existing MemoryInput with a special type that works without the corresponding MemoryOutput
auto memInputSingle = std::make_shared<MemoryInputSingle>(memInputNode->getId(),
memInputNode->getName(),
memInputNode->getTypeStr(),
memInputNode->getOutputShapeAtPort(0),
memInputNode->getOriginalOutputPrecisionAtPort(0),
graph.getGraphContext(),
inputShape,
inputPrc);

graph.AddNode(memInputSingle);

if (!memInputNode->getParentEdges().empty()) {
auto parentEdge = memInputNode->getParentEdgeAt(0);
auto parent = parentEdge->getParent();
const auto inputNum = parentEdge->getInputNum();
graph.RemoveEdge(parentEdge);
graph.CreateEdge(parent, memInputSingle, inputNum, 0);
}

for (auto&& edge : memInputNode->getChildEdgesAtPort(0)) {
auto child = edge->getChild();
const auto outputNum = edge->getOutputNum();
graph.RemoveEdge(edge);
graph.CreateEdge(memInputSingle, child, 0, outputNum);
}
}
}

} // namespace intel_cpu
} // namespace ov
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/graph_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class GraphOptimizer {
void RemoveMemoryInputConvert(Graph &graph);
void RemoveConvertMemoryOutput(Graph &graph);
void MatchSdpaKvCache(Graph &graph);
void DropRedundantMemoryOutput(Graph &graph);

bool canBeInplaced(const NodePtr& parentNode, const NodePtr& childNode);
// Method checks that after the sequential execution of Transpose and Reorder nodes,
Expand Down
43 changes: 43 additions & 0 deletions src/plugins/intel_cpu/src/memory_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,49 @@ MemoryPtr VariableStateDoubleBuffer::internal_state_mem() const {
return prime_mem();
}

VariableStateSingleBuffer::VariableStateSingleBuffer(const std::string& name,
const MemoryPtr& external_buffer,
const MemoryDescPtr& external_desc)
: VariableStateBase(name, external_desc) {
OPENVINO_ASSERT(external_buffer);
m_internal_mem = external_buffer;
m_internal_desc = m_internal_mem->getDescPtr();
auto&& shape = m_internal_desc->getShape();

if (shape.isStatic()) {
m_internal_mem->nullify();
} else {
// in the case of the original desc has dynamic shape we create an empty tensor
auto new_desc = to_static(m_internal_desc);
m_internal_mem->redefineDesc(new_desc);
}
}
MemoryPtr VariableStateSingleBuffer::input_mem() {
return m_internal_mem;
}
MemoryPtr VariableStateSingleBuffer::output_mem() {
return m_internal_mem;
}
MemoryDescPtr VariableStateSingleBuffer::internal_desc() const {
return m_internal_desc;
}

void VariableStateSingleBuffer::reset_impl() {
auto new_desc = to_static(m_internal_desc);
if (m_internal_mem) {
m_internal_mem->redefineDesc(new_desc);
m_internal_mem->nullify();
}
}

MemoryPtr VariableStateSingleBuffer::internal_state_mem() const {
return m_internal_mem;
}

void VariableStateSingleBuffer::commit_impl() {
// nothing to do
}

VariableStateKVcache::VariableStateKVcache(
const std::string& name,
const MemoryDescPtr& external_desc,
Expand Down
21 changes: 21 additions & 0 deletions src/plugins/intel_cpu/src/memory_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,27 @@ class VariableStateDoubleBuffer : public VariableStateBase {
size_t buffer_num = 0;
};

class VariableStateSingleBuffer : public VariableStateBase {
public:
VariableStateSingleBuffer(const std::string& name,
const MemoryPtr& external_buffer,
const MemoryDescPtr& external_desc);

MemoryPtr input_mem() override;
MemoryPtr output_mem() override;
MemoryDescPtr internal_desc() const override;

private:
void reset_impl() override;
void commit_impl() override;

MemoryPtr internal_state_mem() const override;

private:
MemoryDescPtr m_internal_desc; //mem desc required by the graph internal tensor
MemoryPtr m_internal_mem;
};

class VariableStateKVcache : public VariableStateBase {
public:
VariableStateKVcache(const std::string& name,
Expand Down
119 changes: 114 additions & 5 deletions src/plugins/intel_cpu/src/nodes/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,14 +377,16 @@ bool MemoryInputBase::isSupportedOperation(const std::shared_ptr<const ov::Node>
}

MemoryInputBase::MemoryInputBase(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr ctx)
: Input(op, ctx), MemoryStateNode(op) {
: Input(op, ctx),
MemoryStateNode(op) {
std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) {
OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage);
}
if (created()) {
context->getMemoryStatesRegister()->registerInput(this);
}
executeHook = &MemoryInputBase::assignState;
}

MemoryInputBase::MemoryInputBase(const std::string id,
Expand All @@ -394,8 +396,10 @@ MemoryInputBase::MemoryInputBase(const std::string id,
const ov::element::Type& output_prc,
const GraphContext::CPtr context,
const ov::optional<Shape>& input_shape,
const ov::optional<ov::element::Type>& input_prc) :
Input(output_shape, output_prc, name, type, context), MemoryStateNode(id) {
const ov::optional<ov::element::Type>& input_prc,
MemoryInputBase::mode mode)
: Input(output_shape, output_prc, name, type, context),
MemoryStateNode(id) {
outputShapes.emplace_back(output_shape);
addOriginalOutputPrecision(output_prc);
if (input_shape) {
Expand All @@ -411,6 +415,17 @@ MemoryInputBase::MemoryInputBase(const std::string id,
if (created()) {
context->getMemoryStatesRegister()->registerInput(this);
}

// this important to prevent identifying it as a const when it's on a const path
constant = ConstantType::StrictNoConst;

if (mode::read_value_assign == mode) {
executeHook = &MemoryInputBase::assignState;
} else if (mode::single_read_value == mode) {
executeHook = &MemoryInputBase::bypassAssignState;
} else {
THROW_CPU_NODE_ERR("Unexpected MemoryInput mode");
}
}

MemoryInputBase::~MemoryInputBase() {
Expand Down Expand Up @@ -513,15 +528,26 @@ void MemoryInputBase::assignState(MemStatePtr newState) {
}

void MemoryInputBase::execute(dnnl::stream strm) {
getOutputNode().assignState(getAssignedState());
assert(executeHook && "executeHook is not initialized!");
(this->*executeHook)();
runStatic(strm);
}

void MemoryInputBase::executeDynamicImpl(dnnl::stream strm) {
getOutputNode().assignState(getAssignedState());
assert(executeHook && "executeHook is not initialized!");
(this->*executeHook)();
runDynamic(strm);
}

void MemoryInputBase::assignState() {
getOutputNode().assignState(getAssignedState());
}

void MemoryInputBase::bypassAssignState() {
// nothing to do
return;
}

bool MemoryInput::needInitGraphProcessing() const {
return !getParentEdges().empty() && getAssignedState()->is_reset_state();
}
Expand Down Expand Up @@ -828,6 +854,89 @@ void MemoryInputSDPA::resolveInPlaceEdges(Edge::LOOK look) {
}
}

MemoryInputSingle::MemoryInputSingle(const std::string id,
const std::string& name,
const std::string& type,
const Shape& output_shape,
const ov::element::Type& output_prc,
const GraphContext::CPtr context,
const ov::optional<Shape>& input_shape,
const ov::optional<ov::element::Type>& input_prc)
: MemoryInput(id,
name,
type,
output_shape,
output_prc,
context,
input_shape,
input_prc,
MemoryInputBase::mode::single_read_value) {}

MemStatePtr MemoryInputSingle::makeState() const {
// assume ov::Tensor is always dense
auto original_desc =
std::make_shared<CpuBlockedMemoryDesc>(getOriginalOutputPrecisionAtPort(0), outputShapes.at(0));

auto mem_desc = getBaseMemDescAtOutputPort(0);
const auto& eng = getEngine();

auto state_name = getId();

// Remove suffix with pair ID. Internal information.
auto suffix_idx = state_name.find("/id=");
if (suffix_idx != std::string::npos) {
state_name = state_name.substr(0, suffix_idx);
}

return std::make_shared<VariableStateSingleBuffer>(state_name,
std::make_shared<Memory>(eng, mem_desc),
original_desc);
}

void MemoryInputSingle::runStatic(dnnl::stream strm) {
MemoryInput::runStatic(strm);
if (needInitGraphProcessing()) {
// since there is no corresponding MemoryOutput node, we need to update the state here
auto result = getDstMemoryAtPort(0); // only one output port
auto stateMem = getAssignedState()->output_mem();
CPU_NODE_ASSERT(stateMem, " state memory has nullptr");
if (result->getData() != stateMem->getData()) {
stateMem->load(*result);
}
}
getAssignedState()->commit(); // since we don't use MemoryOutput, commit must be called to change the reset state
}

void MemoryInputSingle::runDynamic(dnnl::stream strm) {
MemoryInput::runDynamic(strm);
if (needInitGraphProcessing()) {
// since there is no corresponding MemoryOutput node, we need to update the state here
auto result = getDstMemoryAtPort(0); // only one output port
auto state = getAssignedState();
auto stateMem = state->output_mem();
CPU_NODE_ASSERT(stateMem, " state memory has nullptr");

const auto& newShape = result->getShape();
const auto& stateShape = stateMem->getShape();

if (stateShape.isDynamic() || stateShape.getStaticDims() != newShape.getStaticDims()) {
auto extMemDesc = state->internal_desc();
auto newExternDesc = extMemDesc->cloneWithNewDims(newShape.getStaticDims());
stateMem->redefineDesc(newExternDesc);
}

if (result->getData() != stateMem->getData()) {
stateMem->load(*result);
}
}
getAssignedState()->commit(); // since we don't use MemoryOutput, commit must be called to change the reset state
}

bool MemoryInputSingle::isSupportedOperation(const std::shared_ptr<const ov::Node>& op,
std::string& errorMessage) noexcept {
return MemoryInput::isSupportedOperation(op, errorMessage);
}

} // namespace node
} // namespace intel_cpu
} // namespace ov
Loading
Loading