Skip to content

Commit

Permalink
Support for dynamic shapes in ReadValue/Assign ops and MakeStateful t…
Browse files Browse the repository at this point in the history
…ransformation
  • Loading branch information
itikhono committed Oct 11, 2023
1 parent aa7405f commit f8727d2
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 29 deletions.
7 changes: 7 additions & 0 deletions src/core/include/openvino/op/read_value.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ class OPENVINO_API ReadValue : public util::ReadValueBase {
OPENVINO_OP("ReadValue", "opset6", util::ReadValueBase);
ReadValue() = default;


/// \brief Constructs a ReadValue operation.
///
/// \param variable Class for storing and synchronizing element types, shapes and
/// identifiers between pairs of Assign/ReadValue nodes.
ReadValue(const std::shared_ptr<util::Variable>& variable);

/// \brief Constructs a ReadValue operation.
///
/// \param init_value Node that produces the input tensor.
Expand Down
23 changes: 21 additions & 2 deletions src/core/src/op/assign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,28 @@ Assign::Assign(const Output<Node>& new_value, const std::shared_ptr<util::Variab

void Assign::validate_and_infer_types() {
OV_OP_SCOPE(v6_Assign_validate_and_infer_types);
m_variable->update({get_input_partial_shape(0), get_input_element_type(0), m_variable->get_info().variable_id});
auto variable_info = m_variable->get_info();
auto variable_type = variable_info.data_type;
auto variable_shape = variable_info.data_shape;

auto input_type = get_input_element_type(0);
auto input_shape = get_input_partial_shape(0);
bool compatible_type = variable_type.is_dynamic() || input_type == variable_type;
bool compatible_shape = variable_shape.rank().relaxes(input_shape.rank());

if (compatible_shape && input_shape.rank().is_static() && variable_shape.rank().is_static()) {
OPENVINO_ASSERT(input_shape.rank().get_length() == variable_shape.rank().get_length(),
"Ranks of initial_shape and variable_shape do not match.");
for (size_t i = 0; i < variable_shape.rank().get_length(); ++i) {
compatible_shape = compatible_shape && variable_shape[i].relaxes(input_shape[i]);
}
}

set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
OPENVINO_ASSERT(compatible_shape, "The shape specified in the Variable doesn't match the shape "
"inferred from the initializing subgraph.");
OPENVINO_ASSERT(compatible_type, "The type specified in the Variable doesn't match the type "
"inferred from the initializing subgraph.");
set_output_type(0, input_type, input_shape);
}

std::shared_ptr<Node> Assign::clone_with_new_inputs(const OutputVector& new_args) const {
Expand Down
51 changes: 40 additions & 11 deletions src/core/src/op/read_value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ bool op::v3::ReadValue::visit_attributes(AttributeVisitor& visitor) {
return true;
}

op::v6::ReadValue::ReadValue(const shared_ptr<Variable>& variable) {
m_variable = variable;
constructor_validate_and_infer_types();
}

op::v6::ReadValue::ReadValue(const Output<Node>& init_value, const shared_ptr<Variable>& variable)
: ReadValueBase({init_value}) {
m_variable = variable;
Expand All @@ -50,19 +55,39 @@ op::v6::ReadValue::ReadValue(const Output<Node>& init_value, const shared_ptr<Va

void op::v6::ReadValue::validate_and_infer_types() {
OV_OP_SCOPE(v6_ReadValue_validate_and_infer_types);
const auto arg_t = get_input_element_type(0);
const auto& input_shape = get_input_partial_shape(0);

OPENVINO_ASSERT(m_variable, "Variable is not initialized.");
VariableInfo var_info = {input_shape, element::dynamic, m_variable->get_info().variable_id};
NODE_VALIDATION_CHECK(this,
element::Type::merge(var_info.data_type, m_variable->get_info().data_type, arg_t),
"Variables types are inconsistent.");
NODE_VALIDATION_CHECK(this,
ov::PartialShape::merge_into(var_info.data_shape, m_variable->get_info().data_shape),
"Variable shape and output shape are inconsistent.");
m_variable->update(var_info);
set_output_type(0, arg_t, input_shape);
auto variable_info = m_variable->get_info();
auto variable_type = variable_info.data_type;
auto variable_shape = variable_info.data_shape;

// If no inputs provided, it means this ReadValue doesn't have initial subgraph. This is valid.
if (get_input_size() > 0) {
const auto initial_type = get_input_element_type(0);
const auto &initial_shape = get_input_partial_shape(0);

// Variable shape/type determines a permissible range of values for shape/type inferred from initial_subgraph.
// If initial_subgraph is set, then we need to check that shape/type inferred from initial_subgraph
// is within the permissible range.

bool compatible_type = variable_type.is_dynamic() || initial_type == variable_type;
bool compatible_shape = variable_shape.rank().relaxes(initial_shape.rank());

if (compatible_shape && initial_shape.rank().is_static() && variable_shape.rank().is_static()) {
OPENVINO_ASSERT(initial_shape.rank().get_length() == variable_shape.rank().get_length(),
"Ranks of initial_shape and variable_shape do not match.");
for (size_t i = 0; i < variable_shape.rank().get_length(); ++i) {
compatible_shape = compatible_shape
&& variable_shape[i].relaxes(initial_shape[i]);
}
}
OPENVINO_ASSERT(compatible_shape, "The shape specified in the Variable doesn't match the shape "
"inferred from the initializing subgraph.");
OPENVINO_ASSERT(compatible_type, "The type specified in the Variable doesn't match the type "
"inferred from the initializing subgraph.");
}

set_output_type(0, variable_type, variable_shape);
}

shared_ptr<Node> op::v6::ReadValue::clone_with_new_inputs(const OutputVector& new_args) const {
Expand All @@ -74,6 +99,10 @@ shared_ptr<Node> op::v6::ReadValue::clone_with_new_inputs(const OutputVector& ne
bool op::v6::ReadValue::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v6_ReadValue_visit_attributes);
visitor.on_attribute("variable_id", m_variable);

auto variable_info = m_variable->get_info();
visitor.on_attribute("variable_shape", variable_info.data_shape);
visitor.on_attribute("variable_type", variable_info.data_type);
return true;
}

Expand Down
10 changes: 2 additions & 8 deletions src/core/src/pass/make_stateful.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,16 @@ bool ov::pass::MakeStateful::run_on_model(const std::shared_ptr<ov::Model>& f) {
const auto& param = m_param_res_pairs[i].first;
const auto& res = m_param_res_pairs[i].second;

OPENVINO_ASSERT(param->get_partial_shape().is_static(),
"Shape of Parameter ",
param->get_friendly_name(),
" must be static. MakeStateful transformation doesn't support dynamic shapes.");

// Create Variable
std::string var_name = variable_names[i];
auto variable = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{param->get_shape(), param->get_element_type(), var_name});
variables.push_back(variable);

// Create ReadValue
auto const_zero = std::make_shared<ov::op::v0::Constant>(param->get_element_type(), param->get_shape(), 0);
auto read_val = std::make_shared<ov::op::v6::ReadValue>(const_zero, variable);
auto read_val = std::make_shared<ov::op::v6::ReadValue>(variable);
replace_node(param, read_val);
ov::copy_runtime_info(param, {read_val, const_zero});
ov::copy_runtime_info(param, read_val);

// Create Assign
auto assign = std::make_shared<ov::op::v6::Assign>(res->input_value(0), variable);
Expand Down
95 changes: 87 additions & 8 deletions src/core/tests/type_prop/assign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

#include "common_test_utils/type_prop.hpp"
#include "openvino/core/model.hpp"
#include "openvino/op/read_value.hpp"
#include "openvino/op/assign.hpp"
#include "openvino/op/util/variable.hpp"

using namespace std;
using namespace ov;

TEST(type_prop, assign_variable_not_found) {
auto A = make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2, 64, 64});
Expand All @@ -35,26 +36,26 @@ TEST(type_prop, assign_deduce) {
ASSERT_EQ(assign->get_shape(), (ov::Shape{1, 2, 64, 64}));
}

TEST(type_prop, assign_read_value_new_shape) {
TEST(type_prop, assign_set_new_shape_allowed_range) {
auto input = make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::Shape{4, 3, 2, 1});

auto variable = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape::dynamic(), ov::element::dynamic, "ID"});
auto read_value = make_shared<ov::op::v6::ReadValue>(input, variable);
auto assign = make_shared<ov::op::v6::Assign>(read_value, variable);

ASSERT_EQ(assign->get_element_type(), ov::element::f16);
ASSERT_EQ(assign->get_shape(), (ov::Shape{4, 3, 2, 1}));
ASSERT_EQ(assign->get_element_type(), ov::element::dynamic);
ASSERT_EQ(assign->get_output_partial_shape(0), (ov::PartialShape::dynamic()));

auto m = std::make_shared<ov::Model>(ov::ResultVector{}, ov::SinkVector{assign}, ov::ParameterVector{input});

input->set_partial_shape({3, {4, 5}, 8});
m->validate_nodes_and_infer_types();

ASSERT_EQ(assign->get_element_type(), ov::element::f16);
ASSERT_EQ(assign->get_output_partial_shape(0), (ov::PartialShape{3, {4, 5}, 8}));
ASSERT_EQ(variable->get_info().data_type, ov::element::f16);
ASSERT_EQ(variable->get_info().data_shape, (ov::PartialShape{3, {4, 5}, 8}));
ASSERT_EQ(assign->get_element_type(), ov::element::dynamic);
ASSERT_EQ(assign->get_output_partial_shape(0), (ov::PartialShape::dynamic()));
ASSERT_EQ(variable->get_info().data_type, ov::element::dynamic);
ASSERT_EQ(variable->get_info().data_shape, (ov::PartialShape::dynamic()));
}

TEST(type_prop, variable_comparison) {
Expand All @@ -78,3 +79,81 @@ TEST(type_prop, variable_comparison) {
ASSERT_FALSE(variable1->get_info() == variable4->get_info());
ASSERT_FALSE(variable1->get_info() == variable5->get_info());
}

TEST(type_prop, assign_v6_static_shape_match) {
auto input = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 2, 64, 64});
auto variable = std::make_shared<op::util::Variable>(op::util::VariableInfo{PartialShape{1, 2, 64, 64}, element::f32, "variable_id"});
std::shared_ptr<ov::op::v6::Assign> assign;
EXPECT_NO_THROW(assign = std::make_shared<ov::op::v6::Assign>(input, variable));

ASSERT_EQ(assign->get_element_type(), element::f32);
ASSERT_EQ(assign->get_shape(), (Shape{1, 2, 64, 64}));
}

TEST(type_prop, assign_v6_static_shapes_do_not_match) {
auto input = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2, 64, 64});
auto variable_info = op::util::VariableInfo{PartialShape{1, 2, 64, 64}, element::f32, "variable_id"};
auto variable = std::make_shared<op::util::Variable>(variable_info);
std::shared_ptr<ov::op::v6::Assign> assign;
EXPECT_ANY_THROW(assign = std::make_shared<ov::op::v6::Assign>(input, variable));
}

TEST(type_prop, assign_v6_static_types_do_not_match) {
auto input = make_shared<ov::op::v0::Parameter>(element::i32, Shape{1, 2, 64, 64});
auto variable_info = op::util::VariableInfo{PartialShape{1, 2, 64, 64}, element::f32, "variable_id"};
auto variable = std::make_shared<op::util::Variable>(variable_info);
std::shared_ptr<ov::op::v6::Assign> assign;
EXPECT_ANY_THROW(assign = std::make_shared<ov::op::v6::Assign>(input, variable));
}

TEST(type_prop, assign_v6_dyn_shape_type_in_variable) {
auto input = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 2, 64, 64});

auto variable_info = op::util::VariableInfo{PartialShape{Dimension::dynamic(), 2, Dimension::dynamic(), 64},
element::dynamic, "variable_id"};
auto variable = std::make_shared<op::util::Variable>(variable_info);

std::shared_ptr<ov::op::v6::Assign> assign;
EXPECT_NO_THROW(assign = std::make_shared<ov::op::v6::Assign>(input, variable));

ASSERT_EQ(assign->get_element_type(), element::f32);
ASSERT_EQ(assign->get_output_partial_shape(0), (PartialShape{1, 2, 64, 64}));
ASSERT_EQ(assign->get_variable_id(), "variable_id");
}

TEST(type_prop, assign_v6_init_shape_is_in_range) {
auto input = make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{1, 2, 64, 64});

auto variable_info = op::util::VariableInfo{PartialShape{{1, 10}, {2, 5}, {64, 64}, 64},
element::f32, "variable_id"};
auto variable = std::make_shared<op::util::Variable>(variable_info);

std::shared_ptr<ov::op::v6::Assign> assign;
EXPECT_NO_THROW(assign = std::make_shared<ov::op::v6::Assign>(input, variable));

ASSERT_EQ(assign->get_element_type(), element::f32);
ASSERT_EQ(assign->get_output_partial_shape(0), (PartialShape{1, 2, 64, 64}));
ASSERT_EQ(assign->get_variable_id(), "variable_id");
}

TEST(type_prop, assign_v6_init_shape_is_not_in_range) {
auto input = make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{1, 2, 64, 64});

auto variable_info = op::util::VariableInfo{PartialShape{{2, 5}, {2, 5}, 64, 64},
element::f32, "variable_id"};
auto variable = std::make_shared<op::util::Variable>(variable_info);

std::shared_ptr<ov::op::v6::Assign> assign;
EXPECT_ANY_THROW(assign = std::make_shared<ov::op::v6::Assign>(input, variable));
}

TEST(type_prop, assign_v6_init_shape_is_not_in_range_2) {
auto input = make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{{1, 2}, 2, 64, 64});

auto variable_info = op::util::VariableInfo{PartialShape{1, 2, 64, 64},
element::f32, "variable_id"};
auto variable = std::make_shared<op::util::Variable>(variable_info);

std::shared_ptr<ov::op::v6::Assign> assign;
EXPECT_ANY_THROW(assign = std::make_shared<ov::op::v6::Assign>(input, variable));
}
79 changes: 79 additions & 0 deletions src/core/tests/type_prop/read_value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,83 @@ TEST(type_prop, read_value_deduce) {

ASSERT_EQ(read_value->get_element_type(), element::f32);
ASSERT_EQ(read_value->get_shape(), (Shape{1, 2, 64, 64}));
ASSERT_EQ(read_value->get_variable_id(), "variable_id");
}

TEST(type_prop, read_value_v6_static_shape_match) {
auto input = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 2, 64, 64});
auto variable = std::make_shared<op::util::Variable>(op::util::VariableInfo{PartialShape{1, 2, 64, 64}, element::f32, "variable_id"});
std::shared_ptr<ov::op::v6::ReadValue> read_value;
EXPECT_NO_THROW(read_value = std::make_shared<ov::op::v6::ReadValue>(input, variable));

ASSERT_EQ(read_value->get_element_type(), element::f32);
ASSERT_EQ(read_value->get_shape(), (Shape{1, 2, 64, 64}));
}

TEST(type_prop, read_value_v6_static_shapes_do_not_match) {
auto input = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2, 64, 64});
auto variable_info = op::util::VariableInfo{PartialShape{1, 2, 64, 64}, element::f32, "variable_id"};
auto variable = std::make_shared<op::util::Variable>(variable_info);
std::shared_ptr<ov::op::v6::ReadValue> read_value;
EXPECT_ANY_THROW(read_value = std::make_shared<ov::op::v6::ReadValue>(input, variable));
}

TEST(type_prop, read_value_v6_static_types_do_not_match) {
auto input = make_shared<ov::op::v0::Parameter>(element::i32, Shape{1, 2, 64, 64});
auto variable_info = op::util::VariableInfo{PartialShape{1, 2, 64, 64}, element::f32, "variable_id"};
auto variable = std::make_shared<op::util::Variable>(variable_info);
std::shared_ptr<ov::op::v6::ReadValue> read_value;
EXPECT_ANY_THROW(read_value = std::make_shared<ov::op::v6::ReadValue>(input, variable));
}

TEST(type_prop, read_value_v6_dyn_shape_type_in_variable) {
auto input = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 2, 64, 64});

auto variable_info = op::util::VariableInfo{PartialShape{Dimension::dynamic(), 2, Dimension::dynamic(), 64},
element::dynamic, "variable_id"};
auto variable = std::make_shared<op::util::Variable>(variable_info);

std::shared_ptr<ov::op::v6::ReadValue> read_value;
EXPECT_NO_THROW(read_value = std::make_shared<ov::op::v6::ReadValue>(input, variable));

ASSERT_EQ(read_value->get_element_type(), element::dynamic);
ASSERT_EQ(read_value->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 2, Dimension::dynamic(), 64}));
ASSERT_EQ(read_value->get_variable_id(), "variable_id");
}

TEST(type_prop, read_value_v6_init_shape_is_in_range) {
auto input = make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{1, 2, 64, 64});

auto variable_info = op::util::VariableInfo{PartialShape{{1, 10}, {2, 5}, {64, 64}, 64},
element::f32, "variable_id"};
auto variable = std::make_shared<op::util::Variable>(variable_info);

std::shared_ptr<ov::op::v6::ReadValue> read_value;
EXPECT_NO_THROW(read_value = std::make_shared<ov::op::v6::ReadValue>(input, variable));

ASSERT_EQ(read_value->get_element_type(), element::f32);
ASSERT_EQ(read_value->get_output_partial_shape(0), (PartialShape{{1, 10}, {2, 5}, {64, 64}, 64}));
ASSERT_EQ(read_value->get_variable_id(), "variable_id");
}

TEST(type_prop, read_value_v6_init_shape_is_not_in_range) {
auto input = make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{1, 2, 64, 64});

auto variable_info = op::util::VariableInfo{PartialShape{{2, 5}, {2, 5}, 64, 64},
element::f32, "variable_id"};
auto variable = std::make_shared<op::util::Variable>(variable_info);

std::shared_ptr<ov::op::v6::ReadValue> read_value;
EXPECT_ANY_THROW(read_value = std::make_shared<ov::op::v6::ReadValue>(input, variable));
}

TEST(type_prop, read_value_v6_init_shape_is_not_in_range_2) {
auto input = make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{{1, 2}, 2, 64, 64});

auto variable_info = op::util::VariableInfo{PartialShape{1, 2, 64, 64},
element::f32, "variable_id"};
auto variable = std::make_shared<op::util::Variable>(variable_info);

std::shared_ptr<ov::op::v6::ReadValue> read_value;
EXPECT_ANY_THROW(read_value = std::make_shared<ov::op::v6::ReadValue>(input, variable));
}

0 comments on commit f8727d2

Please sign in to comment.