Skip to content

Commit

Permalink
[transpose sinking] fix SIGSEGV TSDataMovement (#24132)
Browse files Browse the repository at this point in the history
### Details:
 - check constant is not empty

### Tickets:
 - CVS-138875
  • Loading branch information
evkotov authored May 16, 2024
1 parent 7d7fdbe commit 1a87846
Show file tree
Hide file tree
Showing 21 changed files with 111 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ class ov::pass::transpose_sinking::TSForwardBase : public ov::pass::MatcherPass
TSForwardBase() = default;

template <class... Types>
void create_pattern(bool const_transpose_input, std::vector<size_t> transpose_indices = {}) {
m_const_transpose_input = const_transpose_input;
m_tranpose_indices = std::move(transpose_indices);
void create_pattern(std::vector<size_t> transpose_indices = {},
const std::function<bool(const std::shared_ptr<ov::op::v1::Transpose>& transpose,
const std::shared_ptr<ov::op::v0::Constant>& transpose_order)>&
if_transpose_sinkable = utils::if_transpose_sinkable_default) {
m_if_transpose_sinkable = if_transpose_sinkable;
m_transpose_indices = std::move(transpose_indices);
m_pattern = ov::pass::pattern::wrap_type<Types...>([&](const Output<Node>& output) -> bool {
return if_node_has_transpose_inputs(output, m_const_transpose_input, m_tranpose_indices);
return if_node_has_transpose_inputs(output, m_transpose_indices, m_if_transpose_sinkable);
});
}

Expand All @@ -53,11 +56,15 @@ class ov::pass::transpose_sinking::TSForwardBase : public ov::pass::MatcherPass
const utils::TransposeInputsInfo& transpose_info);

private:
static bool if_node_has_transpose_inputs(const Output<Node>& output,
bool const_transpose_input,
const std::vector<size_t>& transpose_indices);
static bool if_node_has_transpose_inputs(
const Output<Node>& output,
const std::vector<size_t>& transpose_indices,
const std::function<bool(const std::shared_ptr<ov::op::v1::Transpose>& transpose,
const std::shared_ptr<ov::op::v0::Constant>& transpose_order)>&);

std::shared_ptr<Node> m_pattern;
bool m_const_transpose_input = true;
std::vector<size_t> m_tranpose_indices;
};
std::function<bool(const std::shared_ptr<ov::op::v1::Transpose>& transpose,
const std::shared_ptr<ov::op::v0::Constant>& transpose_order)>
m_if_transpose_sinkable = utils::if_transpose_sinkable_default;
std::vector<size_t> m_transpose_indices;
};
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,22 @@ struct TransposeInputsInfo {
}
};

/**
* @brief default function to check if we could sink found transpose
*/
bool if_transpose_sinkable_default(const std::shared_ptr<ov::op::v1::Transpose>& transpose,
const std::shared_ptr<ov::op::v0::Constant>& transpose_order);

/**
* @brief Finds node first input that is a transpose operation and returns filled TransposeInputsInfo
* for it
*/
TransposeInputsInfo GetFirstTransposeInput(const std::shared_ptr<ov::Node>&,
bool const_transpose_order,
const std::vector<size_t>& indices = {});
TransposeInputsInfo GetFirstTransposeInput(
const std::shared_ptr<ov::Node>&,
const std::vector<size_t>& indices = {},
const std::function<bool(const std::shared_ptr<ov::op::v1::Transpose>& transpose,
const std::shared_ptr<ov::op::v0::Constant>& transpose_order)>& =
if_transpose_sinkable_default);

/**
* @brief Checks if @arg has any input node that is a transpose operation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void TSForwardBase::transpose_sinking(const std::string& pass_name,
const auto& pattern_to_output = m.get_pattern_value_map();
auto main_node = pattern_to_output.at(m_pattern).get_node_shared_ptr();
utils::TransposeInputsInfo transpose_input_info =
utils::GetFirstTransposeInput(main_node, m_const_transpose_input, m_tranpose_indices);
utils::GetFirstTransposeInput(main_node, m_transpose_indices, m_if_transpose_sinkable);

if (transformation_callback(main_node)) {
mark_as_no_sinking_node(transpose_input_info.transpose);
Expand Down Expand Up @@ -68,10 +68,12 @@ void TSForwardBase::default_outputs_update(const std::shared_ptr<Node>& main_nod
}
}

bool TSForwardBase::if_node_has_transpose_inputs(const Output<Node>& output,
bool const_transpose_input,
const std::vector<size_t>& transpose_indices) {
bool TSForwardBase::if_node_has_transpose_inputs(
const Output<Node>& output,
const std::vector<size_t>& transpose_indices,
const std::function<bool(const std::shared_ptr<ov::op::v1::Transpose>& transpose,
const std::shared_ptr<ov::op::v0::Constant>& transpose_order)>& if_transpose_sinkable) {
utils::TransposeInputsInfo inputs_info =
utils::GetFirstTransposeInput(output.get_node_shared_ptr(), const_transpose_input, transpose_indices);
utils::GetFirstTransposeInput(output.get_node_shared_ptr(), transpose_indices, if_transpose_sinkable);
return !inputs_info.isEmpty();
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ TSBinaryForward::TSBinaryForward() : TSForwardBase() {
op::util::BinaryElementwiseComparison,
op::util::BinaryElementwiseLogical,
ov::op::v0::PRelu,
ov::op::v0::FakeQuantize>(true);
ov::op::v0::FakeQuantize>();
transpose_sinking(matcher_name);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using namespace ov::pass::transpose_sinking::utils;
TSConcatForward::TSConcatForward() {
MATCHER_SCOPE(TSConcatForward);

create_pattern<ov::op::v0::Concat>(true);
create_pattern<ov::op::v0::Concat>();

auto sinking_transformation = [OV_CAPTURE_CPY_AND_THIS](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using namespace ov::pass::transpose_sinking::utils;
TSCumSumForward::TSCumSumForward() {
MATCHER_SCOPE(TSCumSumForward);

create_pattern<ov::op::v0::CumSum>(true, {0});
create_pattern<ov::op::v0::CumSum>({0});

auto sinking_transformation = [OV_CAPTURE_CPY_AND_THIS](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ std::vector<size_t> get_indices_by_op_type(const std::shared_ptr<Node>& main_nod
TSDataMovementForward::TSDataMovementForward() {
MATCHER_SCOPE(TSDataMovementForward);
create_pattern<op::util::PadBase, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
true,
{0});

auto sinking_transformation = [OV_CAPTURE_CPY_AND_THIS](const std::shared_ptr<Node>& main_node,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using namespace ov::pass::transpose_sinking::utils;
TSGatherForward::TSGatherForward() {
MATCHER_SCOPE(TSGatherForward);

create_pattern<ov::op::v8::Gather>(true, {0});
create_pattern<ov::op::v8::Gather>({0});

auto sinking_transformation = [OV_CAPTURE_CPY_AND_THIS](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using namespace ov::pass::transpose_sinking::utils;

TSInterpolateForward::TSInterpolateForward() {
MATCHER_SCOPE(TSInterpolateForward);
create_pattern<ov::op::v4::Interpolate>(true, {0});
create_pattern<ov::op::v4::Interpolate>({0});

auto sinking_transformation = [OV_CAPTURE_CPY_AND_THIS](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ bool get_keep_dims(const std::shared_ptr<Node>& main_node) {
TSReductionForward::TSReductionForward() {
MATCHER_SCOPE(TSReductionForward);

create_pattern<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(true, {0});
create_pattern<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>({0});
auto sinking_transformation = [OV_CAPTURE_CPY_AND_THIS](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
auto keep_dims = get_keep_dims(main_node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
ov::pass::transpose_sinking::TSShapeOfForward::TSShapeOfForward() {
MATCHER_SCOPE(TSShapeOfForward);

create_pattern<op::util::ShapeOfBase>(true);
create_pattern<op::util::ShapeOfBase>();
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
const utils::TransposeInputsInfo& transpose_info) -> bool {
main_node->input(0).replace_source_output(transpose_info.transpose->input_value(0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using namespace ov::pass::transpose_sinking::utils;

TSSliceForward::TSSliceForward() {
MATCHER_SCOPE(TSSliceForward);
create_pattern<ov::op::v8::Slice>(true, {0});
create_pattern<ov::op::v8::Slice>({0});

auto sinking_transformation = [OV_CAPTURE_CPY_AND_THIS](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ bool GetSplitAxis(const std::shared_ptr<ov::op::v0::Constant>& split_axis, const
TSSplitForward::TSSplitForward() {
MATCHER_SCOPE(TSSplitForward);

create_pattern<ov::op::v1::Split, ov::op::v1::VariadicSplit>(true, {0});
create_pattern<ov::op::v1::Split, ov::op::v1::VariadicSplit>({0});

auto sinking_transformation = [OV_CAPTURE_CPY_AND_THIS](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ bool squeeze_axes_to_shape(const Output<Node>& input_node,
TSSqueezeForward::TSSqueezeForward() {
MATCHER_SCOPE(TSSqueezeForward);

create_pattern<ov::op::v0::Squeeze, ov::op::v1::Reshape>(true, {0});
create_pattern<ov::op::v0::Squeeze, ov::op::v1::Reshape>({0});

auto sinking_transformation = [OV_CAPTURE_CPY_AND_THIS](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using namespace ov::pass::transpose_sinking::utils;
TSTileForward::TSTileForward() {
MATCHER_SCOPE(TSTileForward);

create_pattern<ov::op::v0::Tile>(true, {0});
create_pattern<ov::op::v0::Tile>({0});

auto sinking_transformation = [OV_CAPTURE_CPY_AND_THIS](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ using namespace ov::pass::transpose_sinking::utils;
namespace {

using NodePtr = std::shared_ptr<ov::Node>;
using NodePair = std::pair<NodePtr, NodePtr>;

bool if_transpose_sinkable(const std::shared_ptr<ov::op::v1::Transpose>& transpose,
const std::shared_ptr<ov::op::v0::Constant>& transpose_order) {
return static_cast<bool>(transpose);
}

} // namespace

Expand All @@ -59,7 +63,7 @@ TSUnaryForward::TSUnaryForward() {
ov::op::v4::Swish,
ov::op::v0::HardSigmoid,
ov::op::v5::LogSoftmax,
ov::op::v1::ConvertLike>(true, {0});
ov::op::v1::ConvertLike>({0}, if_transpose_sinkable);
auto ts_unary_sinking_function = [this](const std::shared_ptr<Node>& main_node,
const utils::TransposeInputsInfo& transpose_info) -> bool {
bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, {0});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ bool unsqueeze_axes_to_shape(const Output<Node>& input_node,
TSUnsqueezeForward::TSUnsqueezeForward() {
MATCHER_SCOPE(TSUnsqueezeForward);

create_pattern<ov::op::v0::Unsqueeze, ov::op::v1::Reshape>(true, {0});
create_pattern<ov::op::v0::Unsqueeze, ov::op::v1::Reshape>({0});

auto sinking_transformation = [OV_CAPTURE_CPY_AND_THIS](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,22 @@ Output<Node> ChangeAxes(const Output<Node>& indices,
return ChangeAxes(indices, data, axis);
}

TransposeInputsInfo GetFirstTransposeInput(const NodePtr& node,
bool const_transpose_order,
const std::vector<size_t>& indices) {
bool if_transpose_sinkable_default(const std::shared_ptr<ov::op::v1::Transpose>& transpose,
const std::shared_ptr<ov::op::v0::Constant>& transpose_order) {
if (!transpose || !transpose_order)
return false;
const auto partial_shape_rank = transpose->get_input_partial_shape(0).rank();
const auto order = transpose_order->get_axis_vector_val();
if (partial_shape_rank.is_dynamic() && order.empty())
return false;
return true;
}

TransposeInputsInfo GetFirstTransposeInput(
const NodePtr& node,
const std::vector<size_t>& indices,
const std::function<bool(const std::shared_ptr<ov::op::v1::Transpose>& transpose,
const std::shared_ptr<ov::op::v0::Constant>& transpose_order)>& if_transpose_sinkable) {
auto indices_to_check = indices;
if (indices.empty()) {
indices_to_check.resize(node->get_input_size());
Expand All @@ -83,7 +96,7 @@ TransposeInputsInfo GetFirstTransposeInput(const NodePtr& node,
if (!transpose_node)
continue;
auto constant_node = as_type_ptr<ov::op::v0::Constant>(transpose_node->input_value(1).get_node_shared_ptr());
if (const_transpose_order && !constant_node)
if (!if_transpose_sinkable(transpose_node, constant_node))
continue;
{
TransposeInputsInfo input_info;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1325,3 +1325,18 @@ TEST_F(TransformationTestsF, TSBinaryBackwardPReluSlabSpecialRank1) {

manager.register_pass<TSBinaryBackward>();
}

TEST_F(TransformationTestsF, TSBinaryForwardDynamic) {
auto X = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
auto ts_order = std::make_shared<Constant>(element::u64, Shape{0}, Shape{});
auto transpose = std::make_shared<Transpose>(X, ts_order);

auto c1 = std::make_shared<Constant>(element::f32, Shape{0}, Shape{});

auto add = std::make_shared<Add>(transpose, c1);

model = std::make_shared<Model>(ov::OutputVector{add}, ov::ParameterVector{X});
model_ref = model->clone();

manager.register_pass<TSBinaryForward>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -1636,7 +1636,6 @@ auto test_backward_reshape_unsqueeze = []() {
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward,
TSTestFixture,
test_backward_reshape_unsqueeze());

} // namespace common
} // namespace testing
} // namespace transpose_sinking
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,30 @@ INSTANTIATE_TEST_SUITE_P(TSUnaryForwardMultTransposeConsumersTestSuiteFirstNode,
transpose_sinking::testing::unary::test_forward_multiple_consumers_first_node(),
TransposeSinkingUnaryTestFixture::get_test_name);

TEST_F(TransformationTestsF, TSUnaryForwardDynamic) {
{
auto X = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
auto ts_order = std::make_shared<Constant>(element::u64, Shape{0}, Shape{});
auto transpose = std::make_shared<Transpose>(X, ts_order);

auto tanh = std::make_shared<Tanh>(transpose);

model = std::make_shared<Model>(ov::OutputVector{tanh}, ov::ParameterVector{X});

manager.register_pass<TSUnaryForward>();
}
{
auto X = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());

auto tanh = std::make_shared<Tanh>(X);

auto ts_order = std::make_shared<Constant>(element::u64, Shape{0}, Shape{});
auto transpose = std::make_shared<Transpose>(tanh, ts_order);

model_ref = std::make_shared<Model>(ov::OutputVector{transpose}, ov::ParameterVector{X});
}
}

} // namespace unary
} // namespace testing
} // namespace transpose_sinking
} // namespace transpose_sinking

0 comments on commit 1a87846

Please sign in to comment.