From e2ebe0bacc4462884351797cf6eadb34ceb5b56d Mon Sep 17 00:00:00 2001 From: eaidova Date: Thu, 21 Sep 2023 20:42:28 +0400 Subject: [PATCH] apply review comments --- .../transforms/prim_list_unpack_replacer.cpp | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp index f759a4d6f0b740..ae7a39674e99d1 100644 --- a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp @@ -180,21 +180,18 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() { "aten::broadcast_tensors: only prim::ListConstruct supported as input."); return false; } - auto zero = opset10::Constant::create(element::i32, Shape{}, {0}); - Output final_shape_t = zero; + Output final_shape_t = opset10::Constant::create(element::f32, Shape{}, {0}); + ; for (auto input : tensors->inputs()) { - auto tensor_shape = std::make_shared(input.get_source_output()); - auto zero_broadcasted = - std::make_shared(zero, tensor_shape, ov::op::BroadcastType::BIDIRECTIONAL); - final_shape_t = std::make_shared(final_shape_t, zero_broadcasted); + auto tensor = rg.make(input.get_source_output(), element::f32); + final_shape_t = rg.make(final_shape_t, tensor); } - auto final_shape = std::make_shared(final_shape_t, element::i32); + auto final_shape = rg.make(final_shape_t, element::i32); OutputVector outputs; for (auto input : tensors->inputs()) { - outputs.push_back(std::make_shared(input.get_source_output(), - final_shape, - ov::op::BroadcastType::BIDIRECTIONAL)); + outputs.push_back(rg.make(input.get_source_output(), final_shape)); } + copy_runtime_info_and_name(list_unpack, rg.get(), {input_node}); replace_node(list_unpack, outputs); return true; }