TorchFX Dynamic Shapes Additional Support (openvinotoolkit#24773)
### Details:
- Reshape dynamic input support updated to use Concat when the target shape contains unknown dimensions.
 - Dynamic shape support added for the arange and full ops (see the sketch below).
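
A minimal, hedged sketch (not part of this commit) of the kind of TorchFX model these changes target; the `openvino` backend string assumes the OpenVINO torch.compile backend is installed:

```python
import torch

class DynShapeModel(torch.nn.Module):
    def forward(self, x):
        n = x.shape[0]                                # symbolic under dynamic shapes
        r = torch.arange(n, dtype=torch.float32)      # arange with a dynamic end value
        f = torch.full((n, x.shape[1]), 2.0)          # full with dynamic sizes
        y = x.reshape(n, -1) + f                      # reshape with an unknown dimension
        return y + r.unsqueeze(1)

compiled = torch.compile(DynShapeModel(), backend="openvino", dynamic=True)
print(compiled(torch.randn(4, 3)).shape)  # torch.Size([4, 3])
```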

### Tickets:
 - *ticket-id*

---------

Co-authored-by: ynimmaga <[email protected]>
Co-authored-by: Maxim Vafin <[email protected]>
3 people authored Jul 9, 2024
1 parent 6f2f236 commit ac3cdae
Showing 15 changed files with 110 additions and 83 deletions.
@@ -317,6 +317,13 @@ def _raw_input(self, index):
def num_of_outputs(self):
return len(self.outputs())

def output_list_size(self):
max_out_id = -1
for user in self.pt_module.users:
if "<built-in function getitem>" == str(user.target) and max_out_id < user.args[1]:
max_out_id = user.args[1]
return max_out_id + 1

def output(self, index):
return self.outputs()[index]

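For context, a hedged Python illustration (not part of the diff) of the getitem-user pattern that `output_list_size` above counts in a traced FX graph:

```python
import operator
import torch

def use_chunks(x):
    parts = torch.chunk(x, 3, dim=1)
    return parts[0] + parts[1] + parts[2]

gm = torch.fx.symbolic_trace(use_chunks)
chunk_node = next(n for n in gm.graph.nodes
                  if n.op == "call_function" and n.target is torch.chunk)
getitem_ids = [n.args[1] for n in chunk_node.users
               if n.target is operator.getitem]
print(max(getitem_ids) + 1)  # 3, the size of the output list
```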
@@ -241,6 +241,7 @@ def __init__(self, options):
"torch.ops.aten.transpose.int": None,
"torch.ops.aten.tril.default": None,
"torch.ops.aten.tril_.default": None,
"torch.ops.aten.triu.default": None,
"torch.ops.aten.unbind.int": None,
"torch.ops.aten.unfold.default": None,
"torch.ops.aten.unsqueeze.default": None,
@@ -82,6 +82,10 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
PYBIND11_OVERRIDE_PURE(size_t, TorchDecoder, num_of_outputs);
}

size_t output_list_size() const override {
PYBIND11_OVERRIDE_PURE(size_t, TorchDecoder, output_list_size);
}

const std::vector<size_t>& outputs() const override {
PYBIND11_OVERRIDE_PURE(const std::vector<size_t>&, TorchDecoder, outputs);
}
@@ -82,6 +82,10 @@ class TorchDecoder : public IDecoder {
// TODO: use canonical name output_size
virtual size_t num_of_outputs() const = 0;

// If the node output is a list of getitem nodes, returns the size of the list
// If the node output is not a list of getitem nodes, returns 0
virtual size_t output_list_size() const = 0;

// Return a vector of output IDs
virtual const std::vector<size_t>& outputs() const = 0;

12 changes: 12 additions & 0 deletions src/frontends/pytorch/src/op/arange.cpp
@@ -7,6 +7,7 @@
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/squeeze.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"

@@ -108,6 +109,17 @@ OutputVector translate_arange_fx(const NodeContext& context) {
if (context.has_attribute("dtype")) {
dtype = context.get_attribute<element::Type>("dtype");
}
auto input_squeeze = [&context](ov::Output<Node> input) {
if (input.get_partial_shape().rank().is_dynamic() ||
(input.get_partial_shape().rank().is_static() && input.get_partial_shape().rank().get_length() == 1)) {
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
input = context.mark_node(std::make_shared<ov::op::v0::Squeeze>(input, zero));
}
return input;
};
start = input_squeeze(start);
end = input_squeeze(end);
step = input_squeeze(step);
auto range = context.mark_node(std::make_shared<v4::Range>(start, end, step, dtype));
if (!context.has_attribute("dtype")) {
range = context.mark_node(std::make_shared<v1::ConvertLike>(range, context.get_input(0)));
28 changes: 12 additions & 16 deletions src/frontends/pytorch/src/op/expand.cpp
@@ -5,6 +5,8 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/abs.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "utils.hpp"

@@ -46,23 +48,17 @@ OutputVector translate_expand_fx(const NodeContext& context) {
num_inputs_check(context, 2, num_inputs);
auto x = context.get_input(0);
std::vector<int32_t> shape_vec;
auto sizes = context.get_input(1);
if (num_inputs != 2) {
for (size_t i = 1; i < num_inputs; i++) {
auto a = context.get_input_from_visible_context(i).get_node_shared_ptr();
auto shape_input = context.get_input(static_cast<int>(i));
if (std::dynamic_pointer_cast<ov::op::v0::Parameter>(a) ||
shape_input.get_partial_shape().rank().is_dynamic() ||
shape_input.get_partial_shape().rank().get_length() == 0) {
shape_vec.push_back(-1);
} else {
auto val = context.const_input<int32_t>(i);
shape_vec.push_back(val);
}
}
sizes = ov::op::v0::Constant::create(element::i32, Shape{num_inputs - 1}, shape_vec);
if (context.get_input_type(1).is<type::List>()) {
auto concat = concat_list_from_inputs(context, 1, num_inputs);
return base_expand(context, x, concat);
} else {
auto x = context.get_input(0);
auto sizes = context.get_input(1);
// TODO: figure out what implicit means
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(2) || context.const_input<bool>(2) == false,
"Unexpected value of implicit for expand operation");
return base_expand(context, x, sizes);
}
return base_expand(context, x, sizes);
};

} // namespace op
13 changes: 10 additions & 3 deletions src/frontends/pytorch/src/op/full.cpp
@@ -5,6 +5,7 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
@@ -74,9 +75,15 @@ OutputVector translate_full(const NodeContext& context) {
OutputVector translate_full_fx(const NodeContext& context) {
// aten.full.default([16, 16], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'),
// pin_memory = False)
num_inputs_check(context, 2, 2);
auto sizes = context.get_input(0);
auto value = context.get_input(1);
auto num_inputs = context.get_input_size();
num_inputs_check(context, 2, num_inputs);
ov::Output<ov::Node> sizes;
if (context.get_input_type(0).is<type::List>()) {
sizes = concat_list_from_inputs(context, 0, num_inputs - 1);
} else {
sizes = context.get_input(0);
}
auto value = context.get_input(static_cast<int>(num_inputs - 1));

auto filled_tensor = base_translate_full(context, sizes, value);
if (context.has_attribute("dtype")) {
51 changes: 2 additions & 49 deletions src/frontends/pytorch/src/op/reshape.cpp
@@ -31,55 +31,8 @@ OutputVector translate_reshape_fx(const NodeContext& context) {
num_inputs_check(context, 2, num_inputs);
std::vector<int32_t> shape_vec;
if (context.get_input_type(1).is<type::List>()) {
int num_dyn_dims = 0;
for (size_t i = 1; i < num_inputs; i++) {
auto shape_input = context.get_input(static_cast<int>(i));
if (context.get_input_type(i).as<type::List>().element_type.is<type::PyScalar>()) {
auto const_val = context.const_input<int32_t>(i);
shape_vec.push_back(const_val);
} else {
// Set dimension to be dynamic if it's coming from an argument or another node
shape_vec.push_back(-1);
num_dyn_dims++;
}
}
// We cannot use multiple -1s if there are more than 1 dynamic dimensions
if (num_dyn_dims >= 2) {
auto inp_shape = context.get_input(0).get_partial_shape();
// If there are multiple dynamic dymensions, we cannot support inputs with dynamic rank
if (inp_shape.rank().is_static()) {
auto zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {0}));
if (inp_shape.size() >= 3 && inp_shape.size() + 1 == shape_vec.size() && shape_vec[0] == 1 &&
inp_shape[0] == shape_vec[1]) {
// [N, ...] -> [1, N, ...] Can be translated to Unsqueeze
auto unsqueeze =
context.mark_node(std::make_shared<ov::op::v0::Unsqueeze>(context.get_input(0), zero));
return {unsqueeze};
} else if (shape_vec.size() >= 3 && shape_vec.size() + 1 == inp_shape.size() && inp_shape[0] == 1 &&
inp_shape[1] == shape_vec[0]) {
// [1, N, ...] -> [N, ...] Can be translated to Squeeze
auto squeeze = context.mark_node(std::make_shared<ov::op::v0::Squeeze>(context.get_input(0), zero));
return {squeeze};
} else if (inp_shape.size() == shape_vec.size()) {
// If the input rank is equal to output rank, we can use 0s in place of dynamic dimensions
for (size_t k = 0; k < shape_vec.size(); k++) {
if (shape_vec[k] == -1)
shape_vec[k] = 0;
}
} else {
FRONT_END_GENERAL_CHECK(
false,
"Cannot support reshape with multiple dynamic dimensions for unequal ranks");
}
} else {
FRONT_END_GENERAL_CHECK(
false,
"Cannot support reshape with multiple dynamic dimensions for dynamic input ranks");
}
}

auto shape_const = ov::op::v0::Constant::create(element::i32, Shape{num_inputs - 1}, shape_vec);
auto reshape = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), shape_const, true);
auto concat = concat_list_from_inputs(context, 1, num_inputs);
auto reshape = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), concat, true);
return {context.mark_node(reshape)};
} else {
auto shape_input = context.get_input(1);
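A hedged example (not from the commit) of a reshape whose target shape mixes several symbolic dimensions; a constant shape with multiple -1 entries cannot express this, which is why the translator above now builds the target shape with Concat:

```python
import torch

def reshape_two_dynamic(x):             # x: (batch, seq, heads, head_dim)
    b, s = x.shape[0], x.shape[1]        # both symbolic under dynamic shapes
    return x.reshape(b, s, -1)           # target shape carries two runtime values

out = reshape_two_dynamic(torch.randn(2, 5, 4, 8))
print(out.shape)  # torch.Size([2, 5, 32])
```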
12 changes: 6 additions & 6 deletions src/frontends/pytorch/src/op/slice.cpp
@@ -8,7 +8,7 @@

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/reshape.hpp"
#include "utils.hpp"

namespace ov {
@@ -27,11 +27,11 @@ OutputVector translate_slice_common(const NodeContext& context,
int start_idx;
int end_idx;
int step_idx;
auto axis_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto dims_1d_shape = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
if (num_inputs == 5) {
dim = context.get_input(1);
if (dim.get_partial_shape().rank().is_dynamic() || dim.get_partial_shape().rank().get_length() == 0) {
dim = context.mark_node(std::make_shared<v0::Unsqueeze>(dim, axis_0));
dim = context.mark_node(std::make_shared<v1::Reshape>(dim, dims_1d_shape, false));
}
start_idx = 2;
end_idx = 3;
@@ -49,7 +49,7 @@ OutputVector translate_slice_common(const NodeContext& context,
if (!context.input_is_none(start_idx)) {
start = context.get_input(start_idx);
if (start.get_partial_shape().rank().is_dynamic() || start.get_partial_shape().rank().get_length() == 0) {
start = context.mark_node(std::make_shared<v0::Unsqueeze>(start, axis_0));
start = context.mark_node(std::make_shared<v1::Reshape>(start, dims_1d_shape, false));
}
} else {
start = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
@@ -61,7 +61,7 @@ OutputVector translate_slice_common(const NodeContext& context,
// TODO: Find a better way to solve the issue with dynamic ranks for "end"
if ((stop_dynamic_rank_unsqueeze && end.get_partial_shape().rank().is_dynamic()) ||
(!(end.get_partial_shape().rank().is_dynamic()) && end.get_partial_shape().rank().get_length() == 0)) {
end = context.mark_node(std::make_shared<v0::Unsqueeze>(end, axis_0));
end = context.mark_node(std::make_shared<v1::Reshape>(end, dims_1d_shape, false));
}
} else {
end = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {INT_MAX}));
@@ -70,7 +70,7 @@ OutputVector translate_slice_common(const NodeContext& context,
if (!context.input_is_none(step_idx)) {
step = context.get_input(step_idx);
if (step.get_partial_shape().rank().is_dynamic() || step.get_partial_shape().rank().get_length() == 0) {
step = context.mark_node(std::make_shared<v0::Unsqueeze>(step, axis_0));
step = context.mark_node(std::make_shared<v1::Reshape>(step, dims_1d_shape, false));
}
} else {
step = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
10 changes: 5 additions & 5 deletions src/frontends/pytorch/src/op/slice_scatter.cpp
@@ -26,7 +26,7 @@ OutputVector translate_slice_scatter_fx(const NodeContext& context) {
ov::Output<ov::Node> dim;
if (!context.input_is_none(2)) {
dim = context.get_input(2);
if (dim.get_partial_shape().rank().is_dynamic() || dim.get_partial_shape().rank().get_length() == 0) {
if (dim.get_partial_shape().rank().is_static() && dim.get_partial_shape().rank().get_length() == 0) {
dim = context.mark_node(std::make_shared<v0::Unsqueeze>(dim, axis_0));
}
} else {
@@ -35,7 +35,7 @@ OutputVector translate_slice_scatter_fx(const NodeContext& context) {
ov::Output<ov::Node> start;
if (!context.input_is_none(3)) {
start = context.get_input(3);
if (start.get_partial_shape().rank().is_dynamic() || start.get_partial_shape().rank().get_length() == 0) {
if (start.get_partial_shape().rank().is_static() && start.get_partial_shape().rank().get_length() == 0) {
start = context.mark_node(std::make_shared<v0::Unsqueeze>(start, axis_0));
}
} else {
@@ -44,7 +44,7 @@ OutputVector translate_slice_scatter_fx(const NodeContext& context) {
ov::Output<ov::Node> end;
if (!context.input_is_none(4)) {
end = context.get_input(4);
if (end.get_partial_shape().rank().is_dynamic() || end.get_partial_shape().rank().get_length() == 0) {
if (end.get_partial_shape().rank().is_static() && end.get_partial_shape().rank().get_length() == 0) {
end = context.mark_node(std::make_shared<v0::Unsqueeze>(end, axis_0));
}
} else {
@@ -53,7 +53,7 @@ OutputVector translate_slice_scatter_fx(const NodeContext& context) {
ov::Output<ov::Node> step;
if (!context.input_is_none(5)) {
step = context.get_input(5);
if (step.get_partial_shape().rank().is_dynamic() || step.get_partial_shape().rank().get_length() == 0) {
if (step.get_partial_shape().rank().is_static() && step.get_partial_shape().rank().get_length() == 0) {
step = context.mark_node(std::make_shared<v0::Unsqueeze>(step, axis_0));
}
} else {
@@ -65,4 +65,4 @@ OutputVector translate_slice_scatter_fx(const NodeContext& context) {
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
} // namespace ov
16 changes: 14 additions & 2 deletions src/frontends/pytorch/src/op/split.cpp
@@ -23,10 +23,22 @@ OutputVector translate_chunk_fx(const NodeContext& context) {
num_inputs_check(context, 3, 3);
auto num_chunks = context.const_input<int>(1);
auto dim = context.get_input(2);

std::shared_ptr<ov::Node> chunk;
auto dim_val = context.const_input<int>(2);

auto shape = context.get_input(0).get_partial_shape();
if (shape.rank().is_dynamic()) {
size_t num_splits = context.get_decoder()->output_list_size();
std::vector<int32_t> split_lengths_vec;
for (size_t i = 0; i < num_splits - 1; i++) {
split_lengths_vec.push_back(num_chunks);
}
split_lengths_vec.push_back(-1);
auto split_lengths =
context.mark_node(v0::Constant::create(element::i32, Shape{num_splits}, split_lengths_vec));
auto split = context.mark_node(std::make_shared<v1::VariadicSplit>(context.get_input(0), dim, split_lengths));
return {context.mark_node(make_list_construct(split->outputs()))};
}
auto dim_val = context.const_input<int>(2);
if (dim_val < 0) {
dim_val = static_cast<int>(shape.rank().get_length()) + dim_val;
}
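As a hedged aside (not part of the commit), the split-lengths vector built above relies on VariadicSplit treating a single -1 entry as "whatever remains"; a small numpy sketch of that semantics:

```python
import numpy as np

def variadic_split(x, split_lengths, axis=0):
    sizes = list(split_lengths)
    if -1 in sizes:
        i = sizes.index(-1)
        sizes[i] = x.shape[axis] - (sum(sizes) + 1)  # +1 cancels the -1 placeholder
    return np.split(x, np.cumsum(sizes)[:-1], axis=axis)

print([p.tolist() for p in variadic_split(np.arange(10), [3, 3, -1])])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
```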
3 changes: 2 additions & 1 deletion src/frontends/pytorch/src/op_table.cpp
@@ -765,7 +765,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten._scaled_dot_product_flash_attention_for_cpu.default", op::translate_scaled_dot_product_attention_fx},
{"aten._softmax.default", op::translate_softmax_fx},
{"aten._to_copy.default", op::translate_to_fx},
{"aten._unsafe_view.default", op::translate_reshape},
{"aten._unsafe_view.default", op::translate_reshape_fx},
{"aten.abs.default", op::translate_1to1_match_1_inputs<opset10::Abs>},
{"aten.acos.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acos>},
{"aten.acosh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acosh>},
@@ -970,6 +970,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.topk.default", op::translate_topk_fx},
{"aten.transpose.int", op::translate_transpose},
{"aten.tril.default", op::translate_tril},
{"aten.triu.default", op::translate_triu},
{"aten.unbind.int", op::translate_unbind_int_fx},
{"aten.unfold.default", op::translate_unfold},
{"aten.unsqueeze.default", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
24 changes: 24 additions & 0 deletions src/frontends/pytorch/src/utils.cpp
@@ -555,6 +555,30 @@ Output<Node> masked_fill(ov::pass::NodeRegistry& rg,
return rg.make<opset10::Select>(bool_mask, _value, data);
}

Output<Node> concat_list_from_inputs(const NodeContext& context, size_t begin, size_t end) {
OutputVector list_elems;
for (size_t i = begin; i < end; i++) {
if (context.get_input_type(i).as<type::List>().element_type.is<type::PyScalar>()) {
auto const_val = context.const_input<int64_t>(i);
std::vector<int64_t> dim_vec;
dim_vec.push_back(const_val);
auto dim_const = ov::op::v0::Constant::create(element::i64, Shape{1}, dim_vec);
list_elems.push_back(dim_const);
} else {
auto input_dim = context.get_input(static_cast<int>(i));
if (input_dim.get_partial_shape().rank() == 0) {
auto zero = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
auto unsqueezed_dim = context.mark_node(std::make_shared<ov::op::v0::Unsqueeze>(input_dim, zero));
list_elems.push_back(unsqueezed_dim);
} else {
list_elems.push_back(input_dim);
}
}
}
auto concat = std::make_shared<ov::op::v0::Concat>(list_elems, 0);
return concat;
}

} // namespace pytorch
} // namespace frontend
} // namespace ov
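A hedged, Python-level sketch (not the frontend code itself) of what `concat_list_from_inputs` above assembles: constant list entries become 1-element pieces, dynamic scalar entries are unsqueezed to rank 1, and everything is concatenated into one 1-D shape tensor:

```python
import numpy as np

def concat_shape_pieces(pieces):
    # each piece is a Python int (static dim) or a 0-D array (dynamic dim from the graph)
    parts = [np.atleast_1d(np.asarray(p, dtype=np.int64)) for p in pieces]  # ~Unsqueeze
    return np.concatenate(parts, axis=0)                                    # ~Concat(axis=0)

print(concat_shape_pieces([2, np.int64(7), -1]))  # [ 2  7 -1]
```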
5 changes: 5 additions & 0 deletions src/frontends/pytorch/src/utils.hpp
@@ -119,6 +119,8 @@ Output<Node> masked_fill(ov::pass::NodeRegistry& rg,
const Output<Node>& mask,
const Output<Node>& value);

Output<Node> concat_list_from_inputs(const NodeContext& context, size_t begin, size_t end);

namespace op {
template <OutputVector (*T)(const NodeContext&), size_t idx = 0>
OutputVector inplace_op(const NodeContext& context) {
@@ -267,6 +269,9 @@ class DummyDecoder : public TorchDecoder {
virtual size_t num_of_outputs() const override {
FRONT_END_NOT_IMPLEMENTED(num_of_outputs);
}
virtual size_t output_list_size() const override {
FRONT_END_NOT_IMPLEMENTED(output_list_size);
}
virtual const std::vector<size_t>& outputs() const override {
FRONT_END_NOT_IMPLEMENTED(outputs);
}
3 changes: 2 additions & 1 deletion tests/layer_tests/pytorch_tests/test_expand.py
@@ -128,8 +128,9 @@ def __init__(self, dims):
super(aten_expand, self).__init__()
self.dims = dims

# TODO: Remove the add op after fixing the issue with expand being the last node
def forward(self, x, dym):
return x.expand((self.dims+(dym,)))
return torch.add(x.expand((self.dims+(dym,))), 0)

ref_net = None

