Skip to content

Commit

Permalink
codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Sep 27, 2023
1 parent 8793b42 commit b372a70
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,25 +140,27 @@ bool align_inputs(const std::shared_ptr<ov::op::v1::StridedSlice>& strided_slice
// `stride` input has to be initialized with 1
input_const_val.resize(expected_size, 1);
}
new_inputs[input_idx-1] = ov::op::v0::Constant::create(input_const->get_element_type(), {input_const_val.size()}, input_const_val);
new_inputs[input_idx - 1] = ov::op::v0::Constant::create(input_const->get_element_type(),
{input_const_val.size()},
input_const_val);

copy_runtime_info(input_const, new_inputs[input_idx-1]);
copy_runtime_info(input_const, new_inputs[input_idx - 1]);
}
}
// connect the new begin, end, stride inputs to StridedSlice operation
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
for (size_t i = 1; i <= new_inputs.size(); ++i) {
if (new_inputs[i-1]) {
strided_slice->input(i).replace_source_output(new_inputs[i-1]);
if (new_inputs[i - 1]) {
strided_slice->input(i).replace_source_output(new_inputs[i - 1]);
}
// insert Gather
strided_slice->input(i).replace_source_output(
ChangeValuesOrder(strided_slice->input_value(i), transpose_order_values, axis));
ChangeValuesOrder(strided_slice->input_value(i), transpose_order_values, axis));
}
return true;
}

}
} // namespace

TSStridedSliceForward::TSStridedSliceForward() {
MATCHER_SCOPE(TSStridedSliceForward);
Expand Down Expand Up @@ -209,14 +211,16 @@ TSStridedSliceForward::TSStridedSliceForward() {

// apply shrink_mask to get the correct order.
// the mask have not to be transposed, so apply transpose 2nd time to get the original order.
auto shrink_axes = convert_mask_to_axis_vec(transpose_mask(strided_slice->get_shrink_axis_mask(), transpose_order_values));
auto shrink_axes =
convert_mask_to_axis_vec(transpose_mask(strided_slice->get_shrink_axis_mask(), transpose_order_values));

transpose_order_values = GetOrderAfterReduction(shrink_axes, transpose_order_values);

// add Transpose op to StridedSlice output
auto new_transpose_order = std::make_shared<ov::op::v0::Constant>(transpose_info.transpose_const->get_element_type(),
Shape{transpose_order_values.size()},
transpose_order_values);
auto new_transpose_order =
std::make_shared<ov::op::v0::Constant>(transpose_info.transpose_const->get_element_type(),
Shape{transpose_order_values.size()},
transpose_order_values);

TransposeInputsInfo transpose_input_info = {transpose_info.transpose, new_transpose_order, 0};
strided_slice->validate_and_infer_types();
Expand All @@ -230,14 +234,14 @@ TSStridedSliceForward::TSStridedSliceForward() {
TSStridedSliceBackward::TSStridedSliceBackward() {
MATCHER_SCOPE(TSStridedSliceBackward);

auto main_node_label = wrap_type<ov::op::v1::StridedSlice>([](const Output<Node> &output) -> bool {
auto main_node_label = wrap_type<ov::op::v1::StridedSlice>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && CheckTransposeConsumers(output);
});

auto transpose_const_label = wrap_type<ov::op::v0::Constant>();

auto transpose_label = wrap_type<ov::op::v1::Transpose>({main_node_label, transpose_const_label},
[](const Output<Node> &output) -> bool {
[](const Output<Node>& output) -> bool {
return has_static_rank()(output);
});

Expand All @@ -250,8 +254,8 @@ TSStridedSliceBackward::TSStridedSliceBackward() {
return false;
}

auto strided_slice = ov::as_type_ptr<ov::op::v1::StridedSlice>(main_node);f
if (!strided_slice) {
auto strided_slice = ov::as_type_ptr<ov::op::v1::StridedSlice>(main_node);
f if (!strided_slice) {
return false;
}

Expand Down Expand Up @@ -280,8 +284,8 @@ TSStridedSliceBackward::TSStridedSliceBackward() {
size_t num_elements_in_begin_input = expected_size;

// apply shrink_ mask to get the correct order
auto shrink_axes = convert_shrink_mask_to_axis_vec(strided_slice->get_shrink_axis_mask(),
strided_slice->get_new_axis_mask());
auto shrink_axes =
convert_shrink_mask_to_axis_vec(strided_slice->get_shrink_axis_mask(), strided_slice->get_new_axis_mask());

transpose_order_values = GetOrderBeforeReduction(shrink_axes, transpose_order_values);
if (!align_inputs(strided_slice, transpose_order_values, expected_size, num_elements_in_begin_input)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class StridedSliceFactory : public IFactory {
void set_masks(const StridedSliceMasks& masks) {
m_masks = masks;
}

private:
StridedSliceMasks m_masks;
};
Expand Down Expand Up @@ -109,11 +110,10 @@ auto test_forward_strided_slice = [](const StridedSliceForwardArguments& test_ar
// Reference model description
const auto& ref_transpose_order = test_arguments.reference_transpose_order;
const auto& ref_gather_order = test_arguments.reference_gather_order;
auto new_transpose = [ref_transpose_order](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
auto new_transpose = [ref_transpose_order](const vector<size_t>& idxs,
const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec(out_vec.size());
auto order = make_shared<Constant>(element::i32,
Shape{ref_transpose_order.size()},
ref_transpose_order);
auto order = make_shared<Constant>(element::i32, Shape{ref_transpose_order.size()}, ref_transpose_order);
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
return new_out_vec;
};
Expand All @@ -125,8 +125,8 @@ auto test_forward_strided_slice = [](const StridedSliceForwardArguments& test_ar
}
}


auto update_gather_inputs = [ref_gather_order, new_axes_cnt](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
auto update_gather_inputs = [ref_gather_order, new_axes_cnt](const vector<size_t>& idxs,
const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec = out_vec;
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
size_t expected_size = out_vec[0].get_partial_shape().rank().get_length() + new_axes_cnt;
Expand All @@ -146,11 +146,12 @@ auto test_forward_strided_slice = [](const StridedSliceForwardArguments& test_ar
// `stride` input have to be initialized with 1
input_const_val.resize(expected_size, 1);
}
auto new_input = ov::op::v0::Constant::create(input_const->get_element_type(), {input_const_val.size()}, input_const_val);
auto new_input = ov::op::v0::Constant::create(input_const->get_element_type(),
{input_const_val.size()},
input_const_val);

auto indices = std::make_shared<ov::op::v0::Constant>(element::i32,
Shape{ref_gather_order.size()},
ref_gather_order);
auto indices =
std::make_shared<ov::op::v0::Constant>(element::i32, Shape{ref_gather_order.size()}, ref_gather_order);
new_out_vec[idx] = std::make_shared<ov::op::v8::Gather>(new_input, indices, axis);
}
return new_out_vec;
Expand All @@ -169,10 +170,10 @@ auto test_forward_strided_slice = [](const StridedSliceForwardArguments& test_ar
auto fw_test_1 = []() {
StridedSliceForwardArguments args;
args.inputs_to_main = {
parameter(f32, {7, 10}), // data
constant<int>(i32, {2}, {1, 2}), // begin
constant<int>(i32, {2}, {7, 6}), // end
constant<int>(i32, {2}, {1, 2}) // stride
parameter(f32, {7, 10}), // data
constant<int>(i32, {2}, {1, 2}), // begin
constant<int>(i32, {2}, {7, 6}), // end
constant<int>(i32, {2}, {1, 2}) // stride
};
// empty masks
args.masks.begin = {0, 0};
Expand All @@ -198,10 +199,10 @@ auto fw_test_1 = []() {
auto fw_test_2 = []() {
StridedSliceForwardArguments args;
args.inputs_to_main = {
parameter(f32, {7, 10}), // data
constant<int>(i32, {2}, {1, 2}), // begin
constant<int>(i32, {2}, {7, 6}), // end
constant<int>(i32, {2}, {1, 2}) // stride
parameter(f32, {7, 10}), // data
constant<int>(i32, {2}, {1, 2}), // begin
constant<int>(i32, {2}, {7, 6}), // end
constant<int>(i32, {2}, {1, 2}) // stride
};
// begin and end masks
args.masks.begin = {1, 0};
Expand All @@ -226,15 +227,15 @@ auto fw_test_2 = []() {
auto fw_test_3 = []() {
StridedSliceForwardArguments args;
args.inputs_to_main = {
parameter(f32, {7, 10}), // data
constant<int>(i32, {2}, {1, 2}), // begin
constant<int>(i32, {2}, {7, 6}), // end
constant<int>(i32, {2}, {1, 2}) // stride
parameter(f32, {7, 10}), // data
constant<int>(i32, {2}, {1, 2}), // begin
constant<int>(i32, {2}, {7, 6}), // end
constant<int>(i32, {2}, {1, 2}) // stride
};
// new axis mask
args.masks.begin = {0, 0, 0, 0};
args.masks.end = {0, 0, 0, 0};
args.masks.new_axis = {1, 1, 0 ,0};
args.masks.new_axis = {1, 1, 0, 0};
args.masks.shrink_axis = {0, 0, 0, 0};
args.masks.ellipsis = {0, 0, 0, 0};

Expand All @@ -254,10 +255,10 @@ auto fw_test_3 = []() {
auto fw_test_4 = []() {
StridedSliceForwardArguments args;
args.inputs_to_main = {
parameter(f32, {7, 10}), // data
constant<int>(i32, {2}, {1, 2}), // begin
constant<int>(i32, {2}, {7, 6}), // end
constant<int>(i32, {2}, {1, 2}) // stride
parameter(f32, {7, 10}), // data
constant<int>(i32, {2}, {1, 2}), // begin
constant<int>(i32, {2}, {7, 6}), // end
constant<int>(i32, {2}, {1, 2}) // stride
};
// shrink mask
args.masks.begin = {0, 0};
Expand All @@ -282,10 +283,10 @@ auto fw_test_4 = []() {
auto fw_test_5 = []() {
StridedSliceForwardArguments args;
args.inputs_to_main = {
parameter(f32, {7, 4, 5, 10}), // data
constant<int>(i32, {4}, {1, 2, 2, 1}), // begin
constant<int>(i32, {4}, {5, 4, 4, 4}), // end
constant<int>(i32, {4}, {1, 2, 1, 2}) // stride
parameter(f32, {7, 4, 5, 10}), // data
constant<int>(i32, {4}, {1, 2, 2, 1}), // begin
constant<int>(i32, {4}, {5, 4, 4, 4}), // end
constant<int>(i32, {4}, {1, 2, 1, 2}) // stride
};
// 4dims input, shrink mask
args.masks.begin = {0, 0, 0, 0};
Expand All @@ -310,10 +311,10 @@ auto fw_test_5 = []() {
auto fw_test_6 = []() {
StridedSliceForwardArguments args;
args.inputs_to_main = {
parameter(f32, {7, 10}), // data
constant<int>(i32, {4}, {1, 0, 2, 0}), // begin
constant<int>(i32, {4}, {7, 1, 6, 1}), // end
constant<int>(i32, {4}, {1, 1, 2, 1}) // stride
parameter(f32, {7, 10}), // data
constant<int>(i32, {4}, {1, 0, 2, 0}), // begin
constant<int>(i32, {4}, {7, 1, 6, 1}), // end
constant<int>(i32, {4}, {1, 1, 2, 1}) // stride
};

// mixed masks: new_axis and shrink_mask
Expand All @@ -340,10 +341,10 @@ auto fw_test_6 = []() {
auto fw_test_7 = []() {
StridedSliceForwardArguments args;
args.inputs_to_main = {
parameter(f32, {7, 10}), // data
constant<int>(i32, {4}, {1, 0, 2, 0}), // begin
constant<int>(i32, {4}, {7, 1, 6, 1}), // end
constant<int>(i32, {4}, {1, 1, 2, 1}) // stride
parameter(f32, {7, 10}), // data
constant<int>(i32, {4}, {1, 0, 2, 0}), // begin
constant<int>(i32, {4}, {7, 1, 6, 1}), // end
constant<int>(i32, {4}, {1, 1, 2, 1}) // stride
};

// mixed masks: new_axis and shrink_mask
Expand All @@ -370,10 +371,10 @@ auto fw_test_7 = []() {
auto fw_test_8 = []() {
StridedSliceForwardArguments args;
args.inputs_to_main = {
parameter(f32, {7, 4, 5, 10}), // data
constant<int>(i32, {6}, {0, 0, 1, 2, 2, 1}), // begin
constant<int>(i32, {6}, {1, 1, 5, 4, 4, 4}), // end
constant<int>(i32, {6}, {1, 1, 1, 2, 1, 2}) // stride
parameter(f32, {7, 4, 5, 10}), // data
constant<int>(i32, {6}, {0, 0, 1, 2, 2, 1}), // begin
constant<int>(i32, {6}, {1, 1, 5, 4, 4, 4}), // end
constant<int>(i32, {6}, {1, 1, 1, 2, 1, 2}) // stride
};
// mixed masks: begin, end, shrink, new_axis
args.masks.begin = {0, 0, 0, 1, 0, 0};
Expand All @@ -398,10 +399,10 @@ auto fw_test_8 = []() {
auto fw_test_9 = []() {
StridedSliceForwardArguments args;
args.inputs_to_main = {
parameter(f32, {7, 4, 5, 10}), // data
constant<int>(i32, {4}, {1, 2, 2, 1}), // begin
constant<int>(i32, {4}, {5, 4, 4, 4}), // end
constant<int>(i32, {4}, {1, 2, 1, 2}) // stride
parameter(f32, {7, 4, 5, 10}), // data
constant<int>(i32, {4}, {1, 2, 2, 1}), // begin
constant<int>(i32, {4}, {5, 4, 4, 4}), // end
constant<int>(i32, {4}, {1, 2, 1, 2}) // stride
};
// mixed masks: shrink, new_axis
args.masks.begin = {0, 0, 0, 0};
Expand Down Expand Up @@ -451,11 +452,10 @@ auto test_backward_strided_slice = [](const StridedSliceForwardArguments& test_a
// Reference model description
const auto& ref_transpose_order = test_arguments.reference_transpose_order;
const auto& ref_gather_order = test_arguments.reference_gather_order;
auto new_transpose = [ref_transpose_order](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
auto new_transpose = [ref_transpose_order](const vector<size_t>& idxs,
const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec = out_vec;
auto order = make_shared<Constant>(element::i32,
Shape{ref_transpose_order.size()},
ref_transpose_order);
auto order = make_shared<Constant>(element::i32, Shape{ref_transpose_order.size()}, ref_transpose_order);
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
return new_out_vec;
};
Expand All @@ -467,7 +467,8 @@ auto test_backward_strided_slice = [](const StridedSliceForwardArguments& test_a
}
}

auto update_gather_inputs = [ref_gather_order, new_axes_cnt](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
auto update_gather_inputs = [ref_gather_order, new_axes_cnt](const vector<size_t>& idxs,
const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec = out_vec;
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
size_t expected_size = out_vec[0].get_partial_shape().rank().get_length() + new_axes_cnt;
Expand All @@ -487,11 +488,12 @@ auto test_backward_strided_slice = [](const StridedSliceForwardArguments& test_a
// `stride` input have to be initialized with 1
input_const_val.resize(expected_size, 1);
}
auto new_input = ov::op::v0::Constant::create(input_const->get_element_type(), {input_const_val.size()}, input_const_val);
auto new_input = ov::op::v0::Constant::create(input_const->get_element_type(),
{input_const_val.size()},
input_const_val);

auto indices = std::make_shared<ov::op::v0::Constant>(element::i32,
Shape{ref_gather_order.size()},
ref_gather_order);
auto indices =
std::make_shared<ov::op::v0::Constant>(element::i32, Shape{ref_gather_order.size()}, ref_gather_order);
new_out_vec[idx] = std::make_shared<ov::op::v8::Gather>(new_input, indices, axis);
}
return new_out_vec;
Expand Down

0 comments on commit b372a70

Please sign in to comment.