Skip to content

Commit

Permalink
Symbol Tracking API updated and made public (openvinotoolkit#23136)
Browse files Browse the repository at this point in the history
### Details:
- dev_api `ov::DimensionTracker` and `ov::TableOfEquivalence` classes
deleted, logic moved to `ov::Symbol` which is now stored by
`ov::Dimension`
- new implementation moves responsibility to store and report relations
between Symbols directly to the Symbol object. Hence, there is no need
for `ov::TableOfEquivalence` and no need for synchronization point
anymore.
- Equivalence is being tracked by using
[Disjoint-set_data_structure](https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
which uses less memory than previous implementation.


![image](https://github.com/openvinotoolkit/openvino/assets/55839243/f1266f32-976d-44f9-a6ea-cd04dce07407)


![image](https://github.com/openvinotoolkit/openvino/assets/55839243/3108d1ad-0d30-4041-aa93-c4de1f1fb979)

### Tickets:
 - *CVS-133123*
  • Loading branch information
jane-intel authored and bbielawx committed Apr 12, 2024
1 parent e841688 commit d707a8a
Show file tree
Hide file tree
Showing 187 changed files with 3,345 additions and 3,404 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2021 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -10,7 +10,8 @@
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

using P2Btype = std::unordered_map<std::shared_ptr<ov::opset1::Parameter>, std::unordered_set<ov::label_t>>;
using P2Btype =
std::unordered_map<std::shared_ptr<ov::opset1::Parameter>, std::unordered_set<std::shared_ptr<ov::Symbol>>>;

namespace ov {
namespace pass {
Expand All @@ -33,22 +34,21 @@ class ov::pass::FindBatch : public ov::pass::ModelPass {
};

namespace ov {
class DimensionTracker;

namespace batch_util {
void mark_batch(const std::shared_ptr<ov::opset1::Parameter>& parameter,
P2Btype& map,
const std::unordered_set<label_t>& batches);
const std::unordered_set<std::shared_ptr<Symbol>>& batches);
void mark_no_batch(const std::shared_ptr<ov::opset1::Parameter>& parameter, P2Btype& map);
void mark_layout_independent_batch(const std::shared_ptr<ov::opset1::Parameter>& parameter,
const std::shared_ptr<ov::Node>& result,
P2Btype& map);
void mark_with_unique_dimension_labels(const std::shared_ptr<Model>& m, const ov::DimensionTracker& dt);
void mark_with_unique_dimension_symbols(const std::shared_ptr<Model>& m);
void restore_original_dimensions(
const std::shared_ptr<ov::Model>& model,
const std::map<std::shared_ptr<ov::opset1::Parameter>, ov::PartialShape>& parameter_to_shape,
bool leave_batch_dynamic = true,
bool clear_labels = false);
bool clear_symbols = false);
bool check_batch_tracks_through_all_the_nodes(const std::shared_ptr<ov::Model>& m);
P2Btype find_batch(const std::shared_ptr<ov::Model>& m);
bool detach_detection_output(const std::shared_ptr<ov::Model>& f);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class TRANSFORMATIONS_API ChainedMaximumOptimization;

/**
* @ingroup ov_transformation_common_api
* @brief Optimizes graphs based on value labels / symbols
* @brief Optimizes graphs based on value symbols
* Maximum(Maximum(A, B), B) -> Maximum(A, B)
* Maximum(Maximum(A, B), A) -> Maximum(A, B)
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class TRANSFORMATIONS_API DeReshapeFullyConnected;

/**
* @ingroup ov_transformation_common_api
* @brief Transformation uses symbol / label information to optimize out Reshape operations surrounding MatMul.
* @brief Transformation uses symbol information to optimize out Reshape operations surrounding MatMul.
* It checks that surrounding Reshapes are only manipulating with batch dimensions of tensor in a do-undo kind of way.
*
* Example:
Expand Down Expand Up @@ -46,9 +46,9 @@ class TRANSFORMATIONS_API DeReshapeFullyConnected;
* Binary Elementwise Arithmetic operation without second input scalar restriction.
* MatMul -[-> BEA -]-> Reshape
* this pattern variation is only applicable for the case when input reshapes are 4D -> 3D and output reshape is 3D ->
* 4D. Additionally, shape labels on output of MatMul should be equal to the input shape labels of the last Reshape,
* 4D. Additionally, shape symbols on output of MatMul should be equal to the input shape symbols of the last Reshape,
* meaning that this Binary Elementwise Arithmetic doesn't perform any broadcasting of input coming from MatMul -- only
* other input may be broadcasted to the MatMul input of this BEA. This effect (equality of MatMul output shape labels
* other input may be broadcasted to the MatMul input of this BEA. This effect (equality of MatMul output shape symbols
* and output shape of BEA) is being handled by LabelResolvingThroughSelect transformation in the particular models
* that this variation targets.
*
Expand All @@ -68,7 +68,7 @@ class ov::pass::DeReshapeMatMul : public ov::pass::MatcherPass {

/**
* @ingroup ov_transformation_common_api
* @brief Transformation uses symbol / label information to optimize out Reshape operations surrounding special cases of
* @brief Transformation uses symbol information to optimize out Reshape operations surrounding special cases of
* MatMul. It checks that surrounding Reshapes are only manipulating with batch dimensions of tensor in a do-undo kind
* of way. The difference with previous optimization is that this case has Reshape only on one input of MatMul and the
* other input is strictly 2D. Such MatMuls are also called FullyConnected
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,29 @@

namespace ov {
namespace pass {
class TRANSFORMATIONS_API ApplyTableOfEquivalence;
class TRANSFORMATIONS_API OptimizeLabelsUsedAsValues;
class TRANSFORMATIONS_API ApplySymbolEquivalence;
class TRANSFORMATIONS_API OptimizeSymbolsUsedAsValues;
} // namespace pass
} // namespace ov

/**
* @ingroup ov_transformation_common_api
* @brief Resets symbols / labels on output shapes and values according to table of symbol / label equivalence. It
* allows to reduce number of labels used in the model and to disambiguate label values.
* @brief Resets symbols on output shapes and values according to symbol equivalence. It
* allows to reduce number of labels used in the model and to disambiguate symbol values.
*/
class ov::pass::ApplyTableOfEquivalence : public ov::pass::ModelPass {
class ov::pass::ApplySymbolEquivalence : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("ApplyTableOfEquivalence", "0");
OPENVINO_RTTI("ApplySymbolEquivalence", "0");
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
};

/**
* @ingroup ov_transformation_common_api
* @brief Collects sources where each symbol / label initially appeared (on shape or shape sub-graph) and attaches all
* @brief Collects sources where each symbol initially appeared (on shape or shape sub-graph) and attaches all
* value usages of this label to this initial source
*/
class ov::pass::OptimizeLabelsUsedAsValues : public ov::pass::ModelPass {
class ov::pass::OptimizeSymbolsUsedAsValues : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("OptimizeLabelsUsedAsValues", "0");
OPENVINO_RTTI("OptimizeSymbolsUsedAsValues", "0");
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
};
};
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,7 @@ class ov::pass::SymbolicOptimizations : public ov::pass::ModelPass {
class ov::pass::SymbolicPropagation : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("SymbolicPropagation");
SymbolicPropagation();
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;

private:
std::shared_ptr<ov::TableOfEquivalence> m_te;
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,37 @@ namespace ov {
namespace symbol {
namespace util {

/// \brief Collects labels from shape. Labels of static dimensions are guaranteed to be ov::no_labels
/// \brief Collects symbols from shape. Symbols of static dimensions are guaranteed to be nullptr
///
/// \param shape Shape object to collect labels from
/// \param labels TensorLabel object to collect labels to
/// \param shape Shape object to collect symbols from
/// \param symbols TensorSymbol object to collect symbols to
///
/// \return Status of collecting the labels (false if rank is static else true)
TRANSFORMATIONS_API bool get_labels(const ov::PartialShape& shape, ov::TensorLabel& labels);
/// \return Status of collecting the symbols (false if rank is static else true)
TRANSFORMATIONS_API bool get_symbols(const ov::PartialShape& shape, ov::TensorSymbol& symbols);

/// \brief Collects labels from tensor of Output object
/// \brief Collects symbols from tensor of Output object
///
/// \param output Output object to collect labels from
/// \param labels TensorLabel object to collect labels to
/// \param output Output object to collect symbols from
/// \param symbols TensorSymbol object to collect symbols to
///
/// \return Status of collecting the labels (false if tensor has no labels else true)
TRANSFORMATIONS_API bool get_labels(const ov::Output<ov::Node>& output, ov::TensorLabel& labels);
/// \return Status of collecting the symbols (false if tensor has no symbols else true)
TRANSFORMATIONS_API bool get_symbols(const ov::Output<ov::Node>& output, ov::TensorSymbol& symbols);

/// \brief Compares
///
/// \param lhs TensorLabel object to compare
/// \param rhs TensorLabel object to compare
/// \param lhs TensorSymbol object to compare
/// \param rhs TensorSymbol object to compare
///
/// \return true if labels are unique and equal between lhs and rhs else false
TRANSFORMATIONS_API bool are_unique_and_equal_labels(const ov::TensorLabel& lhs, const ov::TensorLabel& rhs);
/// \return true if symbols are unique and equal between lhs and rhs else false
TRANSFORMATIONS_API bool are_unique_and_equal_symbols(const ov::TensorSymbol& lhs, const ov::TensorSymbol& rhs);

/// \brief Compares dimensions: if dimensions are static compares values of dimensions, if dimensions are dynamic
/// compares their respective labels using TableOfEquivalence
/// compares their respective symbols
///
/// \param lhs Dimension object to compare
/// \param rhs Dimension object to compare
///
/// \return true if static dimensions are equal and dynamic dimensions have equal labels else false
/// \return true if static dimensions are equal and dynamic dimensions have equal symbols else false
TRANSFORMATIONS_API bool dims_are_equal(const ov::Dimension& lhs, const ov::Dimension& rhs);

} // namespace util
Expand Down
Loading

0 comments on commit d707a8a

Please sign in to comment.