Symbol Tracking API updated and made public #23136

Merged: 8 commits, Mar 25, 2024
@@ -1,4 +1,4 @@
-// Copyright (C) 2018-2021 Intel Corporation
+// Copyright (C) 2018-2024 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
@@ -10,7 +10,8 @@
 #include "openvino/pass/graph_rewrite.hpp"
 #include "transformations_visibility.hpp"
 
-using P2Btype = std::unordered_map<std::shared_ptr<ov::opset1::Parameter>, std::unordered_set<ov::label_t>>;
+using P2Btype =
+    std::unordered_map<std::shared_ptr<ov::opset1::Parameter>, std::unordered_set<std::shared_ptr<ov::Symbol>>>;
 
 namespace ov {
 namespace pass {
@@ -33,22 +34,21 @@ class ov::pass::FindBatch : public ov::pass::ModelPass {
 };
 
 namespace ov {
-class DimensionTracker;
 
 namespace batch_util {
 void mark_batch(const std::shared_ptr<ov::opset1::Parameter>& parameter,
                 P2Btype& map,
-                const std::unordered_set<label_t>& batches);
+                const std::unordered_set<std::shared_ptr<Symbol>>& batches);
 void mark_no_batch(const std::shared_ptr<ov::opset1::Parameter>& parameter, P2Btype& map);
 void mark_layout_independent_batch(const std::shared_ptr<ov::opset1::Parameter>& parameter,
                                    const std::shared_ptr<ov::Node>& result,
                                    P2Btype& map);
-void mark_with_unique_dimension_labels(const std::shared_ptr<Model>& m, const ov::DimensionTracker& dt);
+void mark_with_unique_dimension_symbols(const std::shared_ptr<Model>& m);
 void restore_original_dimensions(
     const std::shared_ptr<ov::Model>& model,
     const std::map<std::shared_ptr<ov::opset1::Parameter>, ov::PartialShape>& parameter_to_shape,
     bool leave_batch_dynamic = true,
-    bool clear_labels = false);
+    bool clear_symbols = false);
 bool check_batch_tracks_through_all_the_nodes(const std::shared_ptr<ov::Model>& m);
 P2Btype find_batch(const std::shared_ptr<ov::Model>& m);
 bool detach_detection_output(const std::shared_ptr<ov::Model>& f);
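For context, a minimal sketch of what the updated P2Btype holds after this change: batch dimensions are now tracked as shared ov::Symbol objects attached to dimensions rather than as integer labels. The include paths and the Dimension::set_symbol() accessor are assumptions based on the new public Symbol API, not something shown in this diff.

    #include <memory>
    #include <unordered_map>
    #include <unordered_set>

    #include "openvino/core/symbol.hpp"    // assumed location of ov::Symbol
    #include "openvino/opsets/opset1.hpp"

    int main() {
        // Dynamic batch dimension marked with a freshly created symbol.
        ov::PartialShape shape{ov::Dimension::dynamic(), 3, 224, 224};
        auto batch_symbol = std::make_shared<ov::Symbol>();
        shape[0].set_symbol(batch_symbol);  // assumed accessor on ov::Dimension

        auto parameter = std::make_shared<ov::opset1::Parameter>(ov::element::f32, shape);

        // P2Btype now maps a Parameter to the set of Symbols seen on its batch dimension.
        std::unordered_map<std::shared_ptr<ov::opset1::Parameter>,
                           std::unordered_set<std::shared_ptr<ov::Symbol>>>
            p2b;
        p2b[parameter].insert(batch_symbol);
        return 0;
    }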
(next changed file)
@@ -15,7 +15,7 @@ class TRANSFORMATIONS_API ChainedMaximumOptimization;
 
 /**
  * @ingroup ov_transformation_common_api
- * @brief Optimizes graphs based on value labels / symbols
+ * @brief Optimizes graphs based on value symbols
 * Maximum(Maximum(A, B), B) -> Maximum(A, B)
 * Maximum(Maximum(A, B), A) -> Maximum(A, B)
 */
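As a rough illustration of the pattern described above, the sketch below builds Maximum(Maximum(A, B), B) and runs symbolic propagation followed by this optimization, after which the chained Maximum is expected to collapse to a single Maximum(A, B). The include paths and shapes are assumptions, not part of the PR.

    #include "openvino/opsets/opset1.hpp"
    #include "openvino/pass/manager.hpp"
    #include "transformations/symbolic_transformations/chained_maximum.hpp"         // assumed path
    #include "transformations/symbolic_transformations/symbolic_optimizations.hpp"  // assumed path

    std::shared_ptr<ov::Model> chained_maximum_example() {
        auto a = std::make_shared<ov::opset1::Parameter>(ov::element::i64, ov::PartialShape{4});
        auto b = std::make_shared<ov::opset1::Parameter>(ov::element::i64, ov::PartialShape{4});
        auto inner = std::make_shared<ov::opset1::Maximum>(a, b);
        auto outer = std::make_shared<ov::opset1::Maximum>(inner, b);  // Maximum(Maximum(A, B), B)
        auto model = std::make_shared<ov::Model>(ov::OutputVector{outer}, ov::ParameterVector{a, b});

        ov::pass::Manager manager;
        manager.register_pass<ov::pass::SymbolicPropagation>();         // assigns value symbols
        manager.register_pass<ov::pass::ChainedMaximumOptimization>();  // folds the chained Maximum
        manager.run_passes(model);
        return model;
    }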
(next changed file)
@@ -16,7 +16,7 @@ class TRANSFORMATIONS_API DeReshapeFullyConnected;
 
 /**
  * @ingroup ov_transformation_common_api
- * @brief Transformation uses symbol / label information to optimize out Reshape operations surrounding MatMul.
+ * @brief Transformation uses symbol information to optimize out Reshape operations surrounding MatMul.
 * It checks that surrounding Reshapes are only manipulating with batch dimensions of tensor in a do-undo kind of way.
 *
 * Example:
@@ -46,9 +46,9 @@ class TRANSFORMATIONS_API DeReshapeFullyConnected;
 * Binary Elementwise Arithmetic operation without second input scalar restriction.
 * MatMul -[-> BEA -]-> Reshape
 * this pattern variation is only applicable for the case when input reshapes are 4D -> 3D and output reshape is 3D ->
- * 4D. Additionally, shape labels on output of MatMul should be equal to the input shape labels of the last Reshape,
+ * 4D. Additionally, shape symbols on output of MatMul should be equal to the input shape symbols of the last Reshape,
 * meaning that this Binary Elementwise Arithmetic doesn't perform any broadcasting of input coming from MatMul -- only
- * other input may be broadcasted to the MatMul input of this BEA. This effect (equality of MatMul output shape labels
+ * other input may be broadcasted to the MatMul input of this BEA. This effect (equality of MatMul output shape symbols
 * and output shape of BEA) is being handled by LabelResolvingThroughSelect transformation in the particular models
 * that this variation targets.
 *
@@ -68,7 +68,7 @@ class ov::pass::DeReshapeMatMul : public ov::pass::MatcherPass {
 
 /**
  * @ingroup ov_transformation_common_api
- * @brief Transformation uses symbol / label information to optimize out Reshape operations surrounding special cases of
+ * @brief Transformation uses symbol information to optimize out Reshape operations surrounding special cases of
 * MatMul. It checks that surrounding Reshapes are only manipulating with batch dimensions of tensor in a do-undo kind
 * of way. The difference with previous optimization is that this case has Reshape only on one input of MatMul and the
 * other input is strictly 2D. Such MatMuls are also called FullyConnected
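A hedged sketch of how these matcher passes could be applied to a model containing the do-undo Reshape pattern described above; only the pass registration is shown, and the include paths are assumptions.

    #include "openvino/core/model.hpp"
    #include "openvino/pass/manager.hpp"
    #include "transformations/symbolic_transformations/dereshape_matmul.hpp"        // assumed path
    #include "transformations/symbolic_transformations/symbolic_optimizations.hpp"  // assumed path

    void remove_redundant_reshapes(const std::shared_ptr<ov::Model>& model) {
        ov::pass::Manager manager;
        // Symbols must be propagated first so the DeReshape passes can prove that
        // the surrounding Reshapes only fold/unfold batch dimensions.
        manager.register_pass<ov::pass::SymbolicPropagation>();
        manager.register_pass<ov::pass::DeReshapeMatMul>();
        manager.register_pass<ov::pass::DeReshapeFullyConnected>();
        manager.run_passes(model);
    }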
(next changed file)
@@ -9,29 +9,29 @@
 
 namespace ov {
 namespace pass {
-class TRANSFORMATIONS_API ApplyTableOfEquivalence;
-class TRANSFORMATIONS_API OptimizeLabelsUsedAsValues;
+class TRANSFORMATIONS_API ApplySymbolEquivalence;
+class TRANSFORMATIONS_API OptimizeSymbolsUsedAsValues;
 } // namespace pass
 } // namespace ov
 
 /**
 * @ingroup ov_transformation_common_api
- * @brief Resets symbols / labels on output shapes and values according to table of symbol / label equivalence. It
- * allows to reduce number of labels used in the model and to disambiguate label values.
+ * @brief Resets symbols on output shapes and values according to symbol equivalence. It
+ * allows to reduce number of labels used in the model and to disambiguate symbol values.
 */
-class ov::pass::ApplyTableOfEquivalence : public ov::pass::ModelPass {
+class ov::pass::ApplySymbolEquivalence : public ov::pass::ModelPass {
 public:
-    OPENVINO_RTTI("ApplyTableOfEquivalence", "0");
+    OPENVINO_RTTI("ApplySymbolEquivalence", "0");
     bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
 };
 
 /**
 * @ingroup ov_transformation_common_api
- * @brief Collects sources where each symbol / label initially appeared (on shape or shape sub-graph) and attaches all
+ * @brief Collects sources where each symbol initially appeared (on shape or shape sub-graph) and attaches all
 * value usages of this label to this initial source
 */
-class ov::pass::OptimizeLabelsUsedAsValues : public ov::pass::ModelPass {
+class ov::pass::OptimizeSymbolsUsedAsValues : public ov::pass::ModelPass {
 public:
-    OPENVINO_RTTI("OptimizeLabelsUsedAsValues", "0");
+    OPENVINO_RTTI("OptimizeSymbolsUsedAsValues", "0");
     bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
 };
};
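To make the renamed passes concrete, here is a small sketch of recording that two symbols are equivalent and then letting the passes rewrite the model to one representative symbol per equivalence group. ov::symbol::set_equal and the include paths are assumptions based on the new public Symbol API, not declarations taken from this diff.

    #include "openvino/core/model.hpp"
    #include "openvino/core/symbol.hpp"                                           // assumed location of ov::Symbol
    #include "openvino/pass/manager.hpp"
    #include "transformations/symbolic_transformations/symbol_optimization.hpp"   // assumed path

    void deduplicate_symbols(const std::shared_ptr<ov::Model>& model,
                             const std::shared_ptr<ov::Symbol>& lhs,
                             const std::shared_ptr<ov::Symbol>& rhs) {
        // Record that the two symbols denote the same value (assumed API).
        ov::symbol::set_equal(lhs, rhs);

        ov::pass::Manager manager;
        manager.register_pass<ov::pass::ApplySymbolEquivalence>();       // rewrite shapes/values to representative symbols
        manager.register_pass<ov::pass::OptimizeSymbolsUsedAsValues>();  // reuse each symbol's first source for value usages
        manager.run_passes(model);
    }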
(next changed file)
@@ -43,11 +43,7 @@ class ov::pass::SymbolicOptimizations : public ov::pass::ModelPass {
 class ov::pass::SymbolicPropagation : public ov::pass::ModelPass {
 public:
     OPENVINO_RTTI("SymbolicPropagation");
-    SymbolicPropagation();
     bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
-
-private:
-    std::shared_ptr<ov::TableOfEquivalence> m_te;
 };
 
 /**
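With the TableOfEquivalence member and custom constructor removed, SymbolicPropagation becomes a plain default-constructible model pass. A minimal usage sketch, with the include path assumed:

    #include "openvino/core/model.hpp"
    #include "openvino/pass/manager.hpp"
    #include "transformations/symbolic_transformations/symbolic_optimizations.hpp"  // assumed path

    void propagate_symbols(const std::shared_ptr<ov::Model>& model) {
        ov::pass::Manager manager;
        manager.register_pass<ov::pass::SymbolicPropagation>();  // attaches and propagates ov::Symbol on shapes and values
        manager.run_passes(model);
    }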
(next changed file)
@@ -14,37 +14,37 @@ namespace ov {
 namespace symbol {
 namespace util {
 
-/// \brief Collects labels from shape. Labels of static dimensions are guaranteed to be ov::no_labels
+/// \brief Collects symbols from shape. Symbols of static dimensions are guaranteed to be nullptr
 ///
-/// \param shape Shape object to collect labels from
-/// \param labels TensorLabel object to collect labels to
+/// \param shape Shape object to collect symbols from
+/// \param symbols TensorSymbol object to collect symbols to
 ///
-/// \return Status of collecting the labels (false if rank is static else true)
-TRANSFORMATIONS_API bool get_labels(const ov::PartialShape& shape, ov::TensorLabel& labels);
+/// \return Status of collecting the symbols (false if rank is static else true)
+TRANSFORMATIONS_API bool get_symbols(const ov::PartialShape& shape, ov::TensorSymbol& symbols);
 
-/// \brief Collects labels from tensor of Output object
+/// \brief Collects symbols from tensor of Output object
 ///
-/// \param output Output object to collect labels from
-/// \param labels TensorLabel object to collect labels to
+/// \param output Output object to collect symbols from
+/// \param symbols TensorSymbol object to collect symbols to
 ///
-/// \return Status of collecting the labels (false if tensor has no labels else true)
-TRANSFORMATIONS_API bool get_labels(const ov::Output<ov::Node>& output, ov::TensorLabel& labels);
+/// \return Status of collecting the symbols (false if tensor has no symbols else true)
+TRANSFORMATIONS_API bool get_symbols(const ov::Output<ov::Node>& output, ov::TensorSymbol& symbols);
 
 /// \brief Compares
 ///
-/// \param lhs TensorLabel object to compare
-/// \param rhs TensorLabel object to compare
+/// \param lhs TensorSymbol object to compare
+/// \param rhs TensorSymbol object to compare
 ///
-/// \return true if labels are unique and equal between lhs and rhs else false
-TRANSFORMATIONS_API bool are_unique_and_equal_labels(const ov::TensorLabel& lhs, const ov::TensorLabel& rhs);
+/// \return true if symbols are unique and equal between lhs and rhs else false
+TRANSFORMATIONS_API bool are_unique_and_equal_symbols(const ov::TensorSymbol& lhs, const ov::TensorSymbol& rhs);
 
 /// \brief Compares dimensions: if dimensions are static compares values of dimensions, if dimensions are dynamic
-/// compares their respective labels using TableOfEquivalence
+/// compares their respective symbols
 ///
 /// \param lhs Dimension object to compare
 /// \param rhs Dimension object to compare
 ///
-/// \return true if static dimensions are equal and dynamic dimensions have equal labels else false
+/// \return true if static dimensions are equal and dynamic dimensions have equal symbols else false
 TRANSFORMATIONS_API bool dims_are_equal(const ov::Dimension& lhs, const ov::Dimension& rhs);
 
 } // namespace util
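A short sketch of the renamed utilities in use: collecting symbols from two shapes and comparing their leading dimensions. The include path is an assumption; the function signatures are taken from the declarations above.

    #include "openvino/core/partial_shape.hpp"
    #include "transformations/symbolic_transformations/utils.hpp"  // assumed path

    bool leading_dims_match(const ov::PartialShape& lhs, const ov::PartialShape& rhs) {
        ov::TensorSymbol lhs_symbols, rhs_symbols;
        if (!ov::symbol::util::get_symbols(lhs, lhs_symbols) ||
            !ov::symbol::util::get_symbols(rhs, rhs_symbols))
            return false;  // symbols could not be collected
        if (lhs.rank().get_length() == 0 || rhs.rank().get_length() == 0)
            return false;  // scalars have no leading dimension to compare
        // Static dimensions are compared by value, dynamic ones by their symbols.
        return ov::symbol::util::dims_are_equal(lhs[0], rhs[0]);
    }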