Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Symbol Tracking API updated and made public #23136

Merged
merged 8 commits into from
Mar 25, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class ov::pass::FindBatch : public ov::pass::ModelPass {
};

namespace ov {
class DimensionTracker;

namespace batch_util {
void mark_batch(const std::shared_ptr<ov::opset1::Parameter>& parameter,
Expand All @@ -43,7 +42,7 @@ void mark_no_batch(const std::shared_ptr<ov::opset1::Parameter>& parameter, P2Bt
void mark_layout_independent_batch(const std::shared_ptr<ov::opset1::Parameter>& parameter,
const std::shared_ptr<ov::Node>& result,
P2Btype& map);
void mark_with_unique_dimension_labels(const std::shared_ptr<Model>& m, const ov::DimensionTracker& dt);
void mark_with_unique_dimension_labels(const std::shared_ptr<Model>& m, const std::shared_ptr<ov::LabelTable>& dt);
jane-intel marked this conversation as resolved.
Show resolved Hide resolved
void restore_original_dimensions(
const std::shared_ptr<ov::Model>& model,
const std::map<std::shared_ptr<ov::opset1::Parameter>, ov::PartialShape>& parameter_to_shape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ov::pass::SymbolicPropagation : public ov::pass::ModelPass {
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;

private:
std::shared_ptr<ov::TableOfEquivalence> m_te;
std::shared_ptr<ov::LabelTable> m_te;
jane-intel marked this conversation as resolved.
Show resolved Hide resolved
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ TRANSFORMATIONS_API bool get_labels(const ov::Output<ov::Node>& output, ov::Tens
TRANSFORMATIONS_API bool are_unique_and_equal_labels(const ov::TensorLabel& lhs, const ov::TensorLabel& rhs);

/// \brief Compares dimensions: if dimensions are static compares values of dimensions, if dimensions are dynamic
/// compares their respective labels using TableOfEquivalence
/// compares their respective labels using LabelTable
///
/// \param lhs Dimension object to compare
/// \param rhs Dimension object to compare
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <vector>

#include "itt.hpp"
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/label_table.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
Expand All @@ -22,12 +22,11 @@
#include "openvino/op/shape_of.hpp"

void ov::batch_util::mark_with_unique_dimension_labels(const std::shared_ptr<ov::Model>& m,
const ov::DimensionTracker& dt) {
ov::label_t i = 1;
const std::shared_ptr<ov::LabelTable>& te) {
for (auto& parameter : m->get_parameters()) {
ov::PartialShape new_shape = ov::PartialShape::dynamic(parameter->get_partial_shape().rank());
for (auto& dim : new_shape)
dt.set_up_for_tracking(dim, i++);
te->set_up_for_tracking(dim);
parameter->set_partial_shape(new_shape);
}
m->validate_nodes_and_infer_types();
Expand All @@ -41,23 +40,23 @@ void ov::batch_util::mark_batch(const std::shared_ptr<ov::op::v0::Parameter>& pa
std::unordered_set<ov::label_t> intersection_in_all_three_sources_of_batch;
auto mapped_batches = map[parameter];
for (auto& dim : shape) {
const auto& dim_label = ov::DimensionTracker::get_label(dim);
const auto& dim_label = dim.get_label();
if (batches.count(dim_label) && mapped_batches.count(dim_label)) {
intersection_in_all_three_sources_of_batch.insert(dim_label);
} else {
ov::DimensionTracker::reset_tracking_info(dim);
ov::LabelTable::reset_tracking_info(dim);
}
}
} else {
// two cases possible:
// 1) It is our first time marking batch for this node
// 2) This node was marked as 'no_batch' previously. 'no_batch' has higher priority, batch won't be set
for (auto& dim : shape) {
const auto& dim_label = ov::DimensionTracker::get_label(dim);
const auto& dim_label = dim.get_label();
if (batches.count(dim_label)) { // this is one of the batches
map[parameter].insert(dim_label);
} else {
ov::DimensionTracker::reset_tracking_info(dim);
ov::LabelTable::reset_tracking_info(dim);
}
}
}
Expand All @@ -71,10 +70,10 @@ void ov::batch_util::mark_layout_independent_batch(const std::shared_ptr<ov::op:
TensorLabel p_labels, r_labels;

for (const auto& dim : result->get_output_partial_shape(0))
if (const auto& label = ov::DimensionTracker::get_label(dim))
if (const auto& label = dim.get_label())
r_labels.push_back(label);
for (const auto& dim : parameter->get_partial_shape()) {
if (const auto& label = ov::DimensionTracker::get_label(dim)) {
if (const auto& label = dim.get_label()) {
if (std::find(r_labels.begin(), r_labels.end(), label) != r_labels.end()) {
mark_batch(parameter, map, std::unordered_set<label_t>{label});
return;
Expand All @@ -90,7 +89,7 @@ void ov::batch_util::mark_no_batch(const std::shared_ptr<ov::op::v0::Parameter>&
map.erase(parameter);
auto& shape = parameter->get_partial_shape();
for (auto& dim : shape)
ov::DimensionTracker::reset_tracking_info(dim);
ov::LabelTable::reset_tracking_info(dim);
parameter->set_partial_shape(shape);
parameter->validate_and_infer_types();
}
Expand Down Expand Up @@ -123,7 +122,7 @@ P2Btype ov::batch_util::find_batch(const std::shared_ptr<ov::Model>& f) {
if (type_input_port_batch_index.count(curr_node->get_type_info())) {
auto batch_placement = type_input_port_batch_index[curr_node->get_type_info()];
const auto& shape = curr_node->input_value(batch_placement.first).get_partial_shape();
const auto& batch_dim_label = ov::DimensionTracker::get_label(shape[batch_placement.second]);
const auto& batch_dim_label = shape[batch_placement.second].get_label();
if (batch_dim_label == 0)
mark_no_batch(parameter, parameter_to_batch_labels);
else
Expand All @@ -135,7 +134,7 @@ P2Btype ov::batch_util::find_batch(const std::shared_ptr<ov::Model>& f) {
for (const auto& output : curr_node->outputs()) {
const auto& output_shape = output.get_partial_shape();
bool name_stays = std::any_of(output_shape.cbegin(), output_shape.cend(), [](const Dimension& d) {
return ov::DimensionTracker::get_label(d) != 0;
return d.get_label() != 0;
jane-intel marked this conversation as resolved.
Show resolved Hide resolved
});
all_outputs_labeled &= name_stays;
}
Expand Down Expand Up @@ -180,11 +179,11 @@ void ov::batch_util::restore_original_dimensions(
OPENVINO_ASSERT(batch_marked_shape.size() == original_shape.size());

for (size_t n = 0; n < batch_marked_shape.size(); ++n) {
if (const auto& label = ov::DimensionTracker::get_label(batch_marked_shape[n])) {
if (const auto& label = batch_marked_shape[n].get_label()) {
if (leave_batch_dynamic)
original_shape[n] = Dimension::dynamic();
if (!clear_labels)
ov::DimensionTracker::set_label(original_shape[n], label);
original_shape[n].set_label(label);
}
}
item.first->set_partial_shape(original_shape);
Expand All @@ -203,9 +202,9 @@ void ov::batch_util::restore_original_dimensions(
auto labeled_rank = labeled_shape.rank(), current_rank = current_shape.rank();
if (labeled_rank.is_static() && current_rank.is_static() && labeled_rank == current_rank) {
for (size_t i = 0; i < labeled_shape.size(); ++i) {
auto label = ov::DimensionTracker::get_label(labeled_shape[i]);
auto label = labeled_shape[i].get_label();
if (label != ov::no_label)
ov::DimensionTracker::set_label(current_shape[i], label);
current_shape[i].set_label(label);
}
item.first->set_output_type(0, item.first->get_element_type(), current_shape);
}
Expand All @@ -222,7 +221,7 @@ bool ov::batch_util::check_batch_tracks_through_all_the_nodes(const std::shared_
bool name_stays = false;
bool others_are_static = true;
for (const auto& dim : input_shape)
if (ov::DimensionTracker::get_label(dim) == 0)
if (dim.get_label() == 0)
others_are_static = others_are_static && dim.is_static();
else
name_stays = true;
Expand All @@ -234,7 +233,7 @@ bool ov::batch_util::check_batch_tracks_through_all_the_nodes(const std::shared_
bool name_stays = false;
bool others_are_static = true;
for (const auto& dim : output_shape)
if (ov::DimensionTracker::get_label(dim) == 0)
if (dim.get_label() == 0)
others_are_static = others_are_static && dim.is_static();
else
name_stays = true;
Expand All @@ -250,7 +249,7 @@ bool ov::batch_util::check_batch_tracks_through_all_the_nodes(const std::shared_
for (const auto& result : results) {
const auto& input_shape = result->get_input_partial_shape(0);
bool name_stays = std::any_of(input_shape.cbegin(), input_shape.cend(), [](const ov::Dimension& d) {
return ov::DimensionTracker::get_label(d);
return d.get_label();
});
failed_to_propagate_batch |= !name_stays;
}
Expand Down Expand Up @@ -297,8 +296,7 @@ std::map<std::shared_ptr<ov::op::v0::Parameter>, ov::PartialShape> collect_origi

bool ov::pass::FindBatch::run_on_model(const std::shared_ptr<ov::Model>& m) {
RUN_ON_MODEL_SCOPE(FindBatch);
auto te = std::make_shared<ov::TableOfEquivalence>();
ov::DimensionTracker dt(te);
auto te = std::make_shared<ov::LabelTable>();

bool model_has_changed = false;
if (detach_do)
Expand All @@ -308,7 +306,7 @@ bool ov::pass::FindBatch::run_on_model(const std::shared_ptr<ov::Model>& m) {
if (parameter_to_shape.empty())
return model_has_changed;

ov::batch_util::mark_with_unique_dimension_labels(m, dt);
ov::batch_util::mark_with_unique_dimension_labels(m, te);

ov::batch_util::find_batch(m);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <memory>

#include "itt.hpp"
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/dimension.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
Expand Down Expand Up @@ -52,7 +52,7 @@ shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
continue;
for (ov::Dimension& n : pshape) {
n = ov::Dimension::dynamic();
ov::DimensionTracker::set_label(n, label++);
n.set_label(label++);
}
parameter->set_partial_shape(pshape);
}
Expand All @@ -62,7 +62,7 @@ shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
// if lstm first input has undefined rank or if tracked label is zero -- we failed to track batch dimension
// returning body to initial state
if (lstm_cell->get_input_partial_shape(0).rank().is_dynamic() ||
ov::DimensionTracker::get_label(lstm_cell->get_input_partial_shape(0)[0]) == 0) {
lstm_cell->get_input_partial_shape(0)[0].get_label() == ov::no_label) {
for (auto& item : original_shapes)
item.first->set_partial_shape(item.second);
body->validate_nodes_and_infer_types();
Expand All @@ -73,13 +73,13 @@ shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
shared_ptr<ov::op::v0::Parameter> batch_delivering_parameter;
size_t index_of_batch_dim = 0;

ov::label_t batch_label = ov::DimensionTracker::get_label(lstm_cell->get_input_partial_shape(0)[0]);
ov::label_t batch_label = lstm_cell->get_input_partial_shape(0)[0].get_label();
for (auto& parameter : body->get_parameters()) {
auto pshape = parameter->get_partial_shape();
if (pshape.rank().is_dynamic())
continue;
for (size_t i = 0; i < pshape.size(); ++i) {
if (ov::DimensionTracker::get_label(pshape[i]) == batch_label) {
if (pshape[i].get_label() == batch_label) {
batch_delivering_parameter = parameter;
index_of_batch_dim = i;
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "transformations/symbolic_transformations/chained_maximum.hpp"

#include "itt.hpp"
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/dimension.hpp"
#include "openvino/op/maximum.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/symbolic_transformations/utils.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "transformations/symbolic_transformations/dereshape_matmul.hpp"

#include "itt.hpp"
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/dimension.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include "itt.hpp"
#include "openvino/core/bound_evaluation_util.hpp"
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/label_table.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
Expand Down Expand Up @@ -41,10 +41,10 @@ void apply_table_of_equivalence_on_model(const std::shared_ptr<ov::Model>& m, co
for (auto& d : shape) {
if (d.is_static())
continue;
auto label = ov::DimensionTracker::get_label(d);
auto label = d.get_label();
update_label(table, label);
if (label != ov::no_label)
ov::DimensionTracker::set_label(d, label);
d.set_label(label);
}
op->set_output_type(output.get_index(), output.get_element_type(), shape);
// value relabeling
Expand Down Expand Up @@ -77,7 +77,7 @@ int64_t get_idx_of_label_in_source(const ov::Output<ov::Node>& source, const ov:
if (rank.is_dynamic())
return idx;
for (int64_t i = 0; i < rank.get_length(); ++i) {
auto l = ov::DimensionTracker::get_label(pshape[i]);
auto l = pshape[i].get_label();
if (l == label) {
idx = i;
break;
Expand Down Expand Up @@ -175,13 +175,13 @@ ov::Output<ov::Node> alternative_source_from_concat_input_sources(const LTS_map&
const auto& lhs_pshape = concat->get_input_partial_shape(0);
const auto& rhs_pshape = concat->get_input_partial_shape(1);
if (lhs_pshape.rank().is_static() && rhs_pshape.rank().is_static()) {
auto lhs_label = ov::DimensionTracker::get_label(lhs_pshape[idx]);
auto lhs_label = lhs_pshape[idx].get_label();
auto lhs_alternative = get_alternative_source_from_value_or_shape_source(label_shape_source,
lhs_label,
original_output,
label_value_source);

auto rhs_label = ov::DimensionTracker::get_label(rhs_pshape[idx]);
auto rhs_label = rhs_pshape[idx].get_label();
auto rhs_alternative = get_alternative_source_from_value_or_shape_source(label_shape_source,
rhs_label,
original_output,
Expand Down Expand Up @@ -228,7 +228,7 @@ void save_shape_sources(const ov::Output<ov::Node>& output, LTS_map& label_shape
for (const auto& d : output.get_partial_shape()) {
if (d.is_static())
continue;
auto label = ov::DimensionTracker::get_label(d);
auto label = d.get_label();
if (label == ov::no_label || label_shape_source.count(label))
continue;
label_shape_source[label] = output;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include "compare.hpp"
#include "itt.hpp"
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/dimension.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/maximum.hpp"
#include "openvino/op/shape_of.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "compare.hpp"
#include "itt.hpp"
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/symbolic_transformations/utils.hpp"
#include "transformations/utils/utils.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include "itt.hpp"
#include "openvino/core/descriptor_tensor.hpp"
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/label_table.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/util/symbolic_info.hpp"
Expand All @@ -30,16 +30,16 @@ using namespace ov::pass;
using namespace ov::symbol::util;

namespace {
void symbolic_set_up_for_shape(ov::DimensionTracker& dt, ov::PartialShape& shape) {
void symbolic_set_up_for_shape(const std::shared_ptr<ov::LabelTable>& te, ov::PartialShape& shape) {
if (shape.rank().is_dynamic())
return;
for (auto& d : shape) {
bool is_static = d.is_static(), has_label = ov::DimensionTracker::has_label(d);
bool is_static = d.is_static(), has_label = d.has_label();
if (is_static && has_label)
dt.reset_tracking_info(d); // remove labels from static dims on shapes to reduce label clutter
te->reset_tracking_info(d); // remove labels from static dims on shapes to reduce label clutter
if (is_static || has_label)
continue;
dt.set_up_for_tracking(d);
te->set_up_for_tracking(d);
}
}

Expand Down Expand Up @@ -81,23 +81,22 @@ void special_case_range_label_propagation(const std::shared_ptr<ov::Node>& node)
auto add_in1_label = add_in1_labels[0];

if (add_in0_label == start_label)
ov::DimensionTracker::set_label(output_shape[0], add_in1_label);
output_shape[0].set_label(add_in1_label);
else if (add_in1_label == start_label)
ov::DimensionTracker::set_label(output_shape[0], add_in0_label);
output_shape[0].set_label(add_in0_label);
node->set_output_type(0, node->get_output_element_type(0), output_shape);
}
} // namespace

ov::pass::SymbolicPropagation::SymbolicPropagation() {
m_te = std::make_shared<ov::TableOfEquivalence>();
m_te = std::make_shared<ov::LabelTable>();
}

bool ov::pass::SymbolicPropagation::run_on_model(const std::shared_ptr<ov::Model>& m) {
RUN_ON_MODEL_SCOPE(SymbolicPropagation);

auto te = m_te;
ov::set_up_symbolic_info(m, te);
ov::DimensionTracker dt(te);

for (const auto& op : m->get_ordered_ops()) {
// since we disable invalidation with the following two lines, we have to invalidate manually here
Expand All @@ -117,7 +116,7 @@ bool ov::pass::SymbolicPropagation::run_on_model(const std::shared_ptr<ov::Model

for (auto& output : op->outputs()) {
auto shape = output.get_partial_shape();
symbolic_set_up_for_shape(dt, shape);
symbolic_set_up_for_shape(te, shape);
ov::descriptor::set_tensor_type(output.get_tensor(), output.get_element_type(), shape);
}
}
Expand Down
Loading
Loading