Skip to content

Commit

Permalink
[LPT] IR comparison #2 (has to be squashed)
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Jun 7, 2021
1 parent ed331c2 commit c4cdccc
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -681,24 +681,24 @@ void Comparator::compare_inputs(ngraph::Node* node1, ngraph::Node* node2, std::o
auto const2 = ngraph::as_type_ptr<Constant>(node2->get_input_node_shared_ptr(i));
if (const1 && const2 && !equal_value(const1, const2)) {
err_log << "Different Constant values detected\n"
<< node1->description() << " Input(" << i << ") and "
<< node2->description() << " Input(" << i << ")" << std::endl;
<< description(node1) << " Input(" << i << ") and "
<< description(node2) << " Input(" << i << ")" << std::endl;
}
}

if (should_compare(CmpValues::PRECISIONS)) {
if (node1->input(i).get_element_type() != node2->input(i).get_element_type()) {
err_log << "Different element type detected\n"
<< name(node1) << " Input(" << i << ") "
<< node1->input(i).get_element_type() << " and " << name(node2) << " Input("
<< description(node1) << " Input(" << i << ") "
<< node1->input(i).get_element_type() << " and " << description(node2) << " Input("
<< i << ") " << node2->input(i).get_element_type() << std::endl;
}
}

if (!node1->input(i).get_partial_shape().same_scheme(node2->input(i).get_partial_shape())) {
err_log << "Different shape detected\n"
<< name(node1) << " Input(" << i << ") " << node1->input(i).get_partial_shape()
<< " and " << name(node2) << " Input(" << i << ") "
<< description(node1) << " Input(" << i << ") " << node1->input(i).get_partial_shape()
<< " and " << description(node2) << " Input(" << i << ") "
<< node2->input(i).get_partial_shape() << std::endl;
}

Expand All @@ -707,14 +707,14 @@ void Comparator::compare_inputs(ngraph::Node* node1, ngraph::Node* node2, std::o
auto idx1 = node1->get_input_source_output(i).get_index();
auto idx2 = node2->get_input_source_output(i).get_index();
err_log << "Different ports detected\n"
<< name(node1) << " Input(" << i << ") connected to parent port " << idx1
<< " and " << name(node2) << " Input(" << i << ") connected to parent port "
<< description(node1) << " Input(" << i << ") connected to parent port " << idx1
<< " and " << description(node2) << " Input(" << i << ") connected to parent port "
<< idx2 << std::endl;
}

if (should_compare(CmpValues::RUNTIME_KEYS) && !compare_rt_keys(node1, node2)) {
err_log << "Different runtime info detected\n"
<< name(node1) << " and " << name(node2) << " not equal runtime info."
<< description(node1) << " and " << description(node2) << " not equal runtime info."
<< std::endl;
}
}
Expand All @@ -725,12 +725,12 @@ void Comparator::compare_outputs(ngraph::Node* node1, ngraph::Node* node2, std::
const auto& tensor1 = node1->output(i).get_tensor();
const auto& tensor2 = node2->output(i).get_tensor();

if (tensor1.get_names() != tensor2.get_names()) {
err_log << "Output tensors names " << tensor_names(tensor1) << " and "
<< tensor_names(tensor2)
<< " are different for nodes: " << node1->get_friendly_name() << " and "
<< node2->get_friendly_name() << std::endl;
}
//if (tensor1.get_names() != tensor2.get_names()) {
// err_log << "Output tensors names " << tensor_names(tensor1) << " and "
// << tensor_names(tensor2)
// << " are different for nodes: " << node1->get_friendly_name() << " and "
// << node2->get_friendly_name() << std::endl;
//}

if (!node1->output(i).get_partial_shape().same_scheme(
node2->output(i).get_partial_shape())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ template<typename Node>
std::string name(const Node &n) {
return n->get_friendly_name();
}

template<typename Node>
std::string description(const Node& n) {
return n->get_friendly_name() + " (" + (n->get_type_name()) + ")";
}

}
namespace attributes {

Expand Down Expand Up @@ -832,4 +838,4 @@ class CompareNodesAttributes {

Comparator::Result compare(ngraph::Node* node1, ngraph::Node* node2, Comparator::CmpValues comparition_flags);

} // namespace attributes
} // namespace attributes

0 comments on commit c4cdccc

Please sign in to comment.