diff --git a/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.cpp b/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.cpp index 828a2db5981a89..bcd71a9fd296b6 100644 --- a/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.cpp +++ b/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.cpp @@ -87,13 +87,6 @@ NodePtr InsertUnsqueeze(Output node, size_t n_dims) { return unsqueeze; } -Output FixInputNodeRank(Output input_node, Rank::value_type required_rank) { - const Rank::value_type output_rank = input_node.get_partial_shape().rank().get_length(); - if (output_rank >= required_rank) - return input_node; - return InsertUnsqueeze(input_node, required_rank - output_rank)->output(0); -} - /* Converts gather indices to positive form */ @@ -234,7 +227,7 @@ void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_i const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis, input_node.get_partial_shape().rank().get_length()); auto new_axis_const = std::make_shared(axis_element_type, - Shape{1}, + Shape{}, gather_positive_axis); auto new_gather = std::make_shared(input_node, new_indices_const, new_axis_const); @@ -263,7 +256,7 @@ NodeVector InsertOutputGather(NodePtr main_node, const GatherInputsInfo& gather_ const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis, main_node->output(i).get_partial_shape().rank().get_length()); auto new_axis_const = std::make_shared(axis_element_type, - Shape{1}, + Shape{}, gather_positive_axis); auto new_gather = std::make_shared(main_node->output(i), new_indices_const, new_axis_const); @@ -317,7 +310,7 @@ NodeVector InsertGatherBeforeNode(NodePtr main_node, const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis, input_node.get_partial_shape().rank().get_length()); auto new_axis_const = std::make_shared(axis_element_type, - Shape{1}, + Shape{}, gather_positive_axis); auto new_gather = std::make_shared(input_node, new_indices_const, new_axis_const); diff --git a/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_binary_test.cpp b/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_binary_test.cpp index 9955505c60bcda..3fbfb42c4c270f 100644 --- a/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_binary_test.cpp +++ b/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_binary_test.cpp @@ -81,7 +81,7 @@ std::shared_ptr MakeGather(NodePtr input_node, CreateIndicesF create_ind const std::vector indexes = create_indices_func(input_shape[axis], 0); auto gather_indexes_node = Constant::create(ngraph::element::i64, ov::Shape{indexes.size()}, indexes); - auto gather_axis_node = Constant::create(ngraph::element::i64, ngraph::Shape{1}, {axis}); + auto gather_axis_node = Constant::create(ngraph::element::i64, ngraph::Shape{}, {axis}); return std::make_shared(input_node, gather_indexes_node, gather_axis_node); }