Skip to content

Commit

Permalink
fix unit tests execution
Browse files Browse the repository at this point in the history
  • Loading branch information
evkotov committed Mar 6, 2023
1 parent c4dd724 commit aa3cc35
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,6 @@ NodePtr InsertUnsqueeze(Output<Node> node, size_t n_dims) {
return unsqueeze;
}

Output<Node> FixInputNodeRank(Output<Node> input_node, Rank::value_type required_rank) {
const Rank::value_type output_rank = input_node.get_partial_shape().rank().get_length();
if (output_rank >= required_rank)
return input_node;
return InsertUnsqueeze(input_node, required_rank - output_rank)->output(0);
}

/*
Converts gather indices to positive form
*/
Expand Down Expand Up @@ -234,7 +227,7 @@ void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_i
const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
input_node.get_partial_shape().rank().get_length());
auto new_axis_const = std::make_shared<Constant>(axis_element_type,
Shape{1},
Shape{},
gather_positive_axis);

auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axis_const);
Expand Down Expand Up @@ -263,7 +256,7 @@ NodeVector InsertOutputGather(NodePtr main_node, const GatherInputsInfo& gather_
const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
main_node->output(i).get_partial_shape().rank().get_length());
auto new_axis_const = std::make_shared<Constant>(axis_element_type,
Shape{1},
Shape{},
gather_positive_axis);
auto new_gather = std::make_shared<Gather>(main_node->output(i), new_indices_const, new_axis_const);

Expand Down Expand Up @@ -317,7 +310,7 @@ NodeVector InsertGatherBeforeNode(NodePtr main_node,
const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
input_node.get_partial_shape().rank().get_length());
auto new_axis_const = std::make_shared<Constant>(axis_element_type,
Shape{1},
Shape{},
gather_positive_axis);

auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axis_const);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ std::shared_ptr<Gather> MakeGather(NodePtr input_node, CreateIndicesF create_ind
const std::vector<size_t> indexes = create_indices_func(input_shape[axis], 0);
auto gather_indexes_node = Constant::create(ngraph::element::i64, ov::Shape{indexes.size()}, indexes);

auto gather_axis_node = Constant::create(ngraph::element::i64, ngraph::Shape{1}, {axis});
auto gather_axis_node = Constant::create(ngraph::element::i64, ngraph::Shape{}, {axis});

return std::make_shared<Gather>(input_node, gather_indexes_node, gather_axis_node);
}
Expand Down

0 comments on commit aa3cc35

Please sign in to comment.