Skip to content

Commit

Permalink
Solve unsupported precision issue in transformation rather than init_edge
Browse files Browse the repository at this point in the history
  • Loading branch information
riverlijunjie committed Sep 13, 2023
1 parent eee324d commit b418836
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 103 deletions.
81 changes: 0 additions & 81 deletions src/plugins/intel_cpu/src/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,28 +504,6 @@ void Graph::InitEdges() {
numberOfEdges--;
};

// If the edge's child is a 'Convert' op and the edge's input and output precisions differ, setting the Convert's
// input precision to the edge parent's precision avoids inserting an additional Convert op.
for (ptrdiff_t i = 0; i < numberOfEdges; i++) {
auto edge = graphEdges[i];
if ((edge->getChild()->getType() == Type::Convert) &&
!DnnlExtensionUtils::isSupportedPrecision(edge->getOutputDesc().getPrecision()) &&
edge->getInputDesc().getPrecision() != edge->getOutputDesc().getPrecision()) {
auto convert = edge->getChild();
const auto& inDesc = edge->getInputDesc();
const auto& outDesc = convert->getChildEdgeAt(0)->getInputDesc();
std::string convertName = convert->getName();
DEBUG_LOG("replace convert node: ", convertName);
auto convertNode = std::make_shared<node::Convert>(inDesc.getShape(),
inDesc.getPrecision(),
outDesc.getPrecision(),
convertName,
context);
convertNode->setDescs(inDesc, outDesc);
ReplaceNode(edge, convert, convertNode, true);
}
}

for (ptrdiff_t i = 0; i < numberOfEdges; i++) {
auto edge = graphEdges[i];
auto reorderStatus = graphEdges[i]->needReorder();
Expand Down Expand Up @@ -1682,65 +1660,6 @@ bool Graph::InsertNode(NodePtr parent, NodePtr child, NodePtr node, int parentPo
return true;
}

// Replaces `oldNode` with `newNode` in the graph. The incoming edge `parent` and every child edge of
// `oldNode` are re-created so that they attach to `newNode`, keeping the same port numbers.
//
// @param parent   edge currently leading into `oldNode`; its parent node and input/output numbers are
//                 reused for the new incoming edge.
//                 NOTE(review): only this one incoming edge is rewired - presumably `oldNode` has a
//                 single input; confirm against callers.
// @param oldNode  node to be removed from graphNodes.
// @param newNode  node to be appended to graphNodes in its place.
// @param initNode when true, the descriptor / primitive-descriptor setup sequence below is run on
//                 `newNode` before it is added to graphNodes.
// @return always true.
bool Graph::ReplaceNode(EdgePtr parent, NodePtr oldNode, NodePtr newNode, bool initNode) {
// Erases every occurrence of `_edge` from graphEdges.
auto remove_graph_edge = [&](EdgePtr _edge) {
for (auto it = graphEdges.begin(); it != graphEdges.end();) {
if (*it == _edge) {
it = graphEdges.erase(it);
} else {
++it;
}
}
};
// Erases every entry of the weak-pointer list `edges` that refers to `_edge`.
// NOTE(review): if EdgeWeakPtr is std::weak_ptr, the conversion to EdgePtr throws std::bad_weak_ptr
// for an expired entry - confirm that entries cannot be expired here.
auto remove_edge = [&](EdgePtr _edge, std::vector<ov::intel_cpu::EdgeWeakPtr>& edges) {
for (auto it = edges.begin(); it != edges.end();) {
if (static_cast<EdgePtr>(*it) == _edge) {
it = edges.erase(it);
} else {
++it;
}
}
};

// New incoming edge: same parent node and port numbers as `parent`, but ending at newNode.
EdgePtr beforeNode(new Edge(parent->getParent(), newNode, parent->getInputNum(), parent->getOutputNum()));
// Register the new edge on both of its endpoints, then drop the old edge from the parent node's children.
// The old edge is not removed from oldNode->parentEdges here.
// NOTE(review): presumably oldNode->remove() below takes care of that - confirm.
beforeNode->getChild()->parentEdges.push_back(beforeNode);
parent->getParent()->childEdges.push_back(beforeNode);
remove_edge(parent, parent->getParent()->childEdges);

// There may be multiple child edges
for (size_t i = 0; i < oldNode->childEdges.size(); i++) {
auto edge = oldNode->getChildEdgeAt(i);
// New outgoing edge: newNode -> same child, with the same port numbers as the old edge.
EdgePtr afterNode(new Edge(newNode, edge->getChild(), edge->getInputNum(), edge->getOutputNum()));
afterNode->getParent()->childEdges.push_back(afterNode);
edge->getChild()->parentEdges.push_back(afterNode);
// Detach the old edge from the child and from the graph-wide edge list.
remove_edge(edge, edge->getChild()->parentEdges);
remove_graph_edge(edge);
graphEdges.push_back(afterNode);
}
// Finally swap the incoming edge in the graph-wide edge list.
remove_graph_edge(parent);
graphEdges.push_back(beforeNode);

// Optional initialization of the new node; the calls are made in exactly this order.
if (initNode) {
newNode->getSupportedDescriptors();
newNode->initSupportedPrimitiveDescriptors();
newNode->filterSupportedPrimitiveDescriptors();
newNode->selectOptimalPrimitiveDescriptor();
resolveInPlaceDirection(newNode);
newNode->initOptimalPrimitiveDescriptor();
}
// Add the new node, then remove the old one and erase every occurrence of it from graphNodes.
graphNodes.push_back(newNode);
oldNode->remove();
for (auto it = graphNodes.begin(); it != graphNodes.end();) {
if (*it == oldNode) {
it = graphNodes.erase(it);
} else {
++it;
}
}

return true;
}

// Apply inference precision configuration
void Graph::EnforceInferencePrecision() {
CPU_DEBUG_CAP_ENABLE(static EnforceInferPrcDebug inferPrecDebug);
Expand Down
2 changes: 0 additions & 2 deletions src/plugins/intel_cpu/src/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@ class Graph {
*/
bool InsertNode(NodePtr parent, NodePtr child, NodePtr node, int parentPort, int childPort, bool initNode = false);

bool ReplaceNode(EdgePtr edge, NodePtr oldNode, NodePtr newNode, bool initNode);

std::shared_ptr<ov::Model> dump() const;

void ResetInferCount() { infer_count = 0; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,32 +128,44 @@ namespace intel_cpu {

using const_node_ptr = const std::shared_ptr<const ov::Node>;

bool Transformations::fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
bool Transformations::fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node,
const precisions_map& precisions) {
auto convert = ov::as_type_ptr<ov::opset10::Convert>(node);
if (!convert)
return false;

// For "Parameter->Convert", set a supported precision on the Convert's input tensor to avoid introducing an
// unsupported precision at the dnnl level.
auto parameter_node = ov::as_type_ptr<ov::op::v0::Parameter>(convert->input_value(0).get_node_shared_ptr());
if (parameter_node) {
const auto& prec = node->get_input_element_type(0);
auto item = precisions.find(prec);
if (item != precisions.end()) {
convert->input_value(0).get_tensor().set_element_type(item->second);
}
}

const auto& from = node->get_output_element_type(0);
auto it = precisions.find(from);
if (it == precisions.end())
return false;
const auto& to = it->second;
if (auto convert = ov::as_type_ptr<ov::opset10::Convert>(node)) {
// For Convert node, converting precision from floating point to boolean will lead to mathematical
// error, because here the output precision boolean is replaced by u8. E.g. floating point value 0.01
// is converted to be 1 for boolean, but 0 for u8. Thus an Abs and Ceil node should be added before the
// Convert node for this scenario.
if (convert->input(0).get_element_type().is_real() &&
convert->get_convert_element_type() == ngraph::element::boolean && to.is_integral_number()) {
auto abs = std::make_shared<ov::opset10::Abs>(convert->input_value(0).get_node_shared_ptr());
auto ceil = std::make_shared<ov::opset10::Ceiling>(abs);
auto new_convert = std::make_shared<ov::opset10::Convert>(ceil, to);
new_convert->set_friendly_name(convert->get_friendly_name());
ov::copy_runtime_info(convert, {abs, ceil, new_convert});
ov::replace_node(convert, new_convert);
return true;
} else {
convert->set_convert_element_type(to);
return true;
}
// For Convert node, converting precision from floating point to boolean will lead to mathematical
// error, because here the output precision boolean is replaced by u8. E.g. floating point value 0.01
// is converted to be 1 for boolean, but 0 for u8. Thus an Abs and Ceil node should be added before the
// Convert node for this scenario.
if (convert->input(0).get_element_type().is_real() &&
convert->get_convert_element_type() == ngraph::element::boolean && to.is_integral_number()) {
auto abs = std::make_shared<ov::opset10::Abs>(convert->input_value(0).get_node_shared_ptr());
auto ceil = std::make_shared<ov::opset10::Ceiling>(abs);
auto new_convert = std::make_shared<ov::opset10::Convert>(ceil, to);
new_convert->set_friendly_name(convert->get_friendly_name());
ov::copy_runtime_info(convert, {abs, ceil, new_convert});
ov::replace_node(convert, new_convert);
} else {
convert->set_convert_element_type(to);
}
return false;
return true;
}

void Transformations::UpToLpt() {
Expand Down

0 comments on commit b418836

Please sign in to comment.