Skip to content

Commit

Permalink
correct concat axis handling
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Jul 25, 2024
1 parent c6ef20d commit cc0ac2c
Showing 1 changed file with 18 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,21 @@ TSConcatForward::TSConcatForward() {
return false;
}

if (concat_node->get_output_partial_shape(0).rank().is_dynamic()) {
return false;
auto concat_axis = concat_node->get_axis();
if (concat_axis < 0) {
if (concat_node->get_output_partial_shape(0).rank().is_dynamic()) {
return false;
}
const auto rank = concat_node->get_output_partial_shape(0).rank().get_length();
concat_axis = ov::util::normalize(concat_axis, rank);
}

// todo: support dyn rank case
bool updated = sink_forward::UpdateInputTransposes(main_node, transpose_info);
if (!updated) {
return false;
}

const auto rank = concat_node->get_output_partial_shape(0).rank().get_length();
const auto concat_axis = ov::util::normalize(concat_node->get_axis(), rank);

const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val();
const int64_t transposed_concat_axis = transpose_axis_order[concat_axis];
concat_node->set_axis(transposed_concat_axis);
Expand Down Expand Up @@ -83,12 +86,19 @@ TSConcatBackward::TSConcatBackward() {
}

auto concat_node = as_type_ptr<ov::op::v0::Concat>(main_node);
if (concat_node->get_output_partial_shape(0).rank().is_dynamic()) {
if (!concat_node) {
return false;
}

const auto rank = concat_node->get_output_partial_shape(0).rank().get_length();
auto concat_axis = ov::util::normalize(concat_node->get_axis(), rank);
auto concat_axis = concat_node->get_axis();
if (concat_axis < 0) {
if (concat_node->get_output_partial_shape(0).rank().is_dynamic()) {
return false;
}

const auto rank = concat_node->get_output_partial_shape(0).rank().get_length();
concat_axis = ov::util::normalize(concat_axis, rank);
}

const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_transpose_axis_order = ReverseTransposeOrder(transpose_axis_order);
Expand Down

0 comments on commit cc0ac2c

Please sign in to comment.