Skip to content

Commit

Permalink
multi target inputs support #2
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Apr 11, 2023
1 parent ad46004 commit e99bd39
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 63 deletions.
89 changes: 60 additions & 29 deletions src/common/snippets/src/pass/propagate_precision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,82 +207,113 @@ bool ngraph::snippets::pass::PropagatePrecision::validate_and_infer_types_and_re
op->validate_and_infer_types();
}

// one flag for both options: only one result consumer
bool result_consumer_was_handled = false;
for (size_t i = 0; i < op->get_output_size(); ++i) {
auto output = op->output(i);

if (output.get_element_type() != op_output_types[i]) {
was_updated = true;
auto inputs_to_handle = output.get_target_inputs();
// if at least consumer is result then we should add conversion
// in this case we can not remove existing conversion with result consumers
const auto least_consumer_is_result = std::any_of(
inputs_to_handle.begin(),
inputs_to_handle.end(),
[](const Input<Node>& input) { return ngraph::is_type<ngraph::op::Result>(input.get_node()); });

// option #1: if convertion after the operation exists then check if possible to remove it instead add new
bool converts_can_be_removed = true;
const auto& target_inputs = output.get_target_inputs();
for (const auto& target_input : target_inputs) {
const auto& child = target_input.get_node();
auto it = inputs_to_handle.begin();
while (it != inputs_to_handle.end()) {
const auto& child = it->get_node();
const auto existing_convert = ov::as_type<ngraph::snippets::op::ConvertSaturation>(child);
if (existing_convert == nullptr) {
converts_can_be_removed = false;
break;
it++;
continue;
}

const element::Type actual_before = output.get_element_type();
const element::Type actual_after = op_output_types[i];
const element::Type required_after = existing_convert->output(0).get_element_type();
if (!can_be_removed(actual_before, actual_after, required_after)) {
converts_can_be_removed = false;
break;
}
}
if (converts_can_be_removed) {
for (const auto& target_input : target_inputs) {
const auto& child = target_input.get_node();
const auto existing_convert = ov::as_type<ngraph::snippets::op::ConvertSaturation>(child);
assert(existing_convert != nullptr);
assert(can_be_removed(output.get_element_type(), op_output_types[i], existing_convert->output(0).get_element_type()));
if (can_be_removed(actual_before, actual_after, required_after)) {
const auto convert_target_inputs = existing_convert->output(0).get_target_inputs();

// before removal we should check result cunsumers
if (least_consumer_is_result && std::any_of(
convert_target_inputs.begin(),
convert_target_inputs.end(),
[](const Input<Node>& input) { return ngraph::is_type<ngraph::op::Result>(input.get_node()); })) {
// op has result consumer, not possible to remove convertion to have additional one
it++;
continue;
}

// output.get_target_inputs().size() is increasing in line below
// as result we need to remove existing child_input manually
// issue #107966
existing_convert->output(0).replace(output);
output.remove_target_input(target_input);
output.remove_target_input(*it);

for (auto& input : output.get_target_inputs()) {
// use only affected inputs of conversion
for (auto& input : convert_target_inputs) {
if (ngraph::is_type<ngraph::op::Result>(input.get_node())) {
if (result_consumer_was_handled) {
// only one result consumer is supported
break;
}
result_consumer_was_handled = true;
op->set_friendly_name(existing_convert->get_friendly_name());

// Result input tensor name was changed, the name has to be restored
// issue #107826
input.get_tensor_ptr()->add_names(output.get_tensor_ptr()->get_names());
op->set_friendly_name(existing_convert->get_friendly_name());
}
}
}

const auto current_it = it;
it++;
inputs_to_handle.erase(current_it);
}
if (inputs_to_handle.empty()) {
continue;
}

// option #2: add new convertion after the operation
// option #2: add new convertion after the operation for all not handled inputs before
bool tensor_has_to_be_reset = false;
auto convert = std::make_shared<ngraph::snippets::op::ConvertSaturation>(
output,
op_output_types[i]);
ngraph::copy_runtime_info(output.get_node_shared_ptr(), convert);

for (auto& input : output.get_target_inputs()) {
for (const auto& input : inputs_to_handle) {
auto child = input.get_node();
if (child == convert.get()) {
continue;
}

input.replace_source_output(convert->output(0));


if (ngraph::is_type<ngraph::op::Result>(input.get_node())) {
// Result input tensor name was changed, the name has to be restored
// issue #107826
input.get_tensor_ptr()->add_names(output.get_tensor_ptr()->get_names());

if (result_consumer_was_handled) {
// only one result consumer is supported
break;
}
result_consumer_was_handled = true;
const std::string original_name = op->get_friendly_name();
op->set_friendly_name(original_name + "_original");
convert->set_friendly_name(original_name);

// Result input tensor name was changed, the name has to be restored
// issue #107826
input.get_tensor_ptr()->add_names(output.get_tensor_ptr()->get_names());
tensor_has_to_be_reset = true;
}
}
output.get_tensor_ptr()->set_names({});
if (tensor_has_to_be_reset) {
// Result input tensor name was changed, the previous operation tensor has to be reset
// issue #107826
output.get_tensor_ptr()->set_names({});
}
}
}

Expand Down
Loading

0 comments on commit e99bd39

Please sign in to comment.