From 1d545b259562cc3f9a335cef484883ab55814f9f Mon Sep 17 00:00:00 2001
From: ynimmaga
Date: Wed, 12 Jun 2024 14:01:18 -0700
Subject: [PATCH] Dynamic shape support for llama3

---
 .../pytorch/torchdynamo/op_support.py        |  1 +
 src/frontends/pytorch/src/op/arange.cpp      |  4 ++
 src/frontends/pytorch/src/op/expand.cpp      | 33 +++++++-----
 src/frontends/pytorch/src/op/full.cpp        | 26 ++++++++--
 src/frontends/pytorch/src/op/reshape.cpp     | 52 ++++---------------
 .../pytorch/src/op/slice_scatter.cpp         |  4 +-
 src/frontends/pytorch/src/op_table.cpp       |  3 +-
 7 files changed, 62 insertions(+), 61 deletions(-)

diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
index c2d08bd14638df..f1d0fdb0d36955 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
@@ -241,6 +241,7 @@ def __init__(self, options):
             "torch.ops.aten.transpose.int": None,
             "torch.ops.aten.tril.default": None,
             "torch.ops.aten.tril_.default": None,
+            "torch.ops.aten.triu.default": None,
             "torch.ops.aten.unbind.int": None,
             "torch.ops.aten.unfold.default": None,
             "torch.ops.aten.unsqueeze.default": None,
diff --git a/src/frontends/pytorch/src/op/arange.cpp b/src/frontends/pytorch/src/op/arange.cpp
index d4542f533c0bc0..41e0f8d8bef305 100644
--- a/src/frontends/pytorch/src/op/arange.cpp
+++ b/src/frontends/pytorch/src/op/arange.cpp
@@ -7,6 +7,7 @@
 #include "openvino/op/convert.hpp"
 #include "openvino/op/convert_like.hpp"
 #include "openvino/op/range.hpp"
+#include "openvino/op/squeeze.hpp"
 #include "pt_framework_node.hpp"
 #include "utils.hpp"
 
@@ -108,6 +109,9 @@ OutputVector translate_arange_fx(const NodeContext& context) {
     if (context.has_attribute("dtype")) {
         dtype = context.get_attribute<element::Type>("dtype");
     }
+    if (end.get_partial_shape().rank().is_dynamic()) {
+        end = context.mark_node(std::make_shared<v0::Squeeze>(end, zero));
+    }
     auto range = context.mark_node(std::make_shared<v4::Range>(start, end, step, dtype));
     if (!context.has_attribute("dtype")) {
         range = context.mark_node(std::make_shared<v1::ConvertLike>(range, context.get_input(0)));
diff --git a/src/frontends/pytorch/src/op/expand.cpp b/src/frontends/pytorch/src/op/expand.cpp
index 8e9ce327e647d5..921cd7956cac89 100644
--- a/src/frontends/pytorch/src/op/expand.cpp
+++ b/src/frontends/pytorch/src/op/expand.cpp
@@ -5,6 +5,7 @@
 #include "openvino/frontend/pytorch/node_context.hpp"
 #include "openvino/op/abs.hpp"
 #include "openvino/op/broadcast.hpp"
+#include "openvino/op/concat.hpp"
 #include "openvino/op/shape_of.hpp"
 #include "utils.hpp"
 
@@ -46,23 +47,31 @@ OutputVector translate_expand_fx(const NodeContext& context) {
     num_inputs_check(context, 2, num_inputs);
     auto x = context.get_input(0);
     std::vector<int32_t> shape_vec;
-    auto sizes = context.get_input(1);
-    if (num_inputs != 2) {
+    if (context.get_input_type(1).is<type::List>()) {
+        std::deque<Output<Node>> list_elems;
         for (size_t i = 1; i < num_inputs; i++) {
-            auto a = context.get_input_from_visible_context(i).get_node_shared_ptr();
-            auto shape_input = context.get_input(static_cast<int>(i));
-            if (std::dynamic_pointer_cast<ov::op::v0::Constant>(a) ||
-                shape_input.get_partial_shape().rank().is_dynamic() ||
-                shape_input.get_partial_shape().rank().get_length() == 0) {
-                shape_vec.push_back(-1);
+            if (context.get_input_type(i).as<type::List>().element_type.is<type::PyScalar>()) {
+                auto const_val = context.const_input<int32_t>(i);
+                std::vector<int32_t> dim_vec;
+                dim_vec.push_back(const_val);
+                auto dim_const = ov::op::v0::Constant::create(element::i32, Shape{1}, dim_vec);
+                list_elems.push_back(dim_const);
             } else {
-                auto val = context.const_input<int32_t>(i);
-                shape_vec.push_back(val);
+                auto converted_dim = context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(static_cast<int>(i)), element::i32));
+                list_elems.push_back(converted_dim);
             }
         }
-        sizes = ov::op::v0::Constant::create(element::i32, Shape{num_inputs - 1}, shape_vec);
+        auto concat = std::make_shared<ov::op::v0::Concat>(OutputVector(list_elems.begin(), list_elems.end()), 0);
+        return base_expand(context, x, concat);
+    } else {
+        auto x = context.get_input(0);
+        auto sizes = context.get_input(1);
+        // TODO: figure out what implicit means
+        PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(2) || context.const_input<bool>(2) == false,
+                                    "Unexpected value of implicit for expand operation");
+        return base_expand(context, x, sizes);
     }
-    return base_expand(context, x, sizes);
+
 };
 
 }  // namespace op
diff --git a/src/frontends/pytorch/src/op/full.cpp b/src/frontends/pytorch/src/op/full.cpp
index dc3fa1f677b58e..55cd76416570c5 100644
--- a/src/frontends/pytorch/src/op/full.cpp
+++ b/src/frontends/pytorch/src/op/full.cpp
@@ -8,6 +8,7 @@
 #include "openvino/op/constant.hpp"
 #include "openvino/op/convert_like.hpp"
 #include "openvino/op/divide.hpp"
+#include "openvino/op/concat.hpp"
 #include "openvino/op/gather.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/power.hpp"
@@ -74,9 +75,28 @@ OutputVector translate_full(const NodeContext& context) {
 OutputVector translate_full_fx(const NodeContext& context) {
     // aten.full.default([16, 16], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'),
     // pin_memory = False)
-    num_inputs_check(context, 2, 2);
-    auto sizes = context.get_input(0);
-    auto value = context.get_input(1);
+    auto num_inputs = context.get_input_size();
+    num_inputs_check(context, 2, num_inputs);
+    ov::Output<Node> sizes;
+    if (context.get_input_type(0).is<type::List>()) {
+        std::deque<Output<Node>> list_elems;
+        for (size_t i = 0; i < num_inputs-1; i++) {
+            if (context.get_input_type(i).as<type::List>().element_type.is<type::PyScalar>()) {
+                auto const_val = context.const_input<int32_t>(i);
+                std::vector<int32_t> dim_vec;
+                dim_vec.push_back(const_val);
+                auto dim_const = ov::op::v0::Constant::create(element::i32, Shape{1}, dim_vec);
+                list_elems.push_back(dim_const);
+            } else {
+                auto converted_dim = context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(static_cast<int>(i)), element::i32));
+                list_elems.push_back(converted_dim);
+            }
+        }
+        sizes = std::make_shared<ov::op::v0::Concat>(OutputVector(list_elems.begin(), list_elems.end()), 0);
+    } else {
+        sizes = context.get_input(0);
+    }
+    auto value = context.get_input(num_inputs-1);
     auto filled_tensor = base_translate_full(context, sizes, value);
 
     if (context.has_attribute("dtype")) {
diff --git a/src/frontends/pytorch/src/op/reshape.cpp b/src/frontends/pytorch/src/op/reshape.cpp
index edea4c7aefb44a..dac29bfd31ff9e 100644
--- a/src/frontends/pytorch/src/op/reshape.cpp
+++ b/src/frontends/pytorch/src/op/reshape.cpp
@@ -31,55 +31,21 @@ OutputVector translate_reshape_fx(const NodeContext& context) {
     num_inputs_check(context, 2, num_inputs);
     std::vector<int32_t> shape_vec;
     if (context.get_input_type(1).is<type::List>()) {
-        int num_dyn_dims = 0;
+        std::deque<Output<Node>> list_elems;
         for (size_t i = 1; i < num_inputs; i++) {
-            auto shape_input = context.get_input(static_cast<int>(i));
             if (context.get_input_type(i).as<type::List>().element_type.is<type::PyScalar>()) {
                 auto const_val = context.const_input<int32_t>(i);
-                shape_vec.push_back(const_val);
+                std::vector<int32_t> dim_vec;
+                dim_vec.push_back(const_val);
+                auto dim_const = ov::op::v0::Constant::create(element::i32, Shape{1}, dim_vec);
+                list_elems.push_back(dim_const);
             } else {
-                // Set dimension to be dynamic if it's coming from an argument or another node
-                shape_vec.push_back(-1);
-                num_dyn_dims++;
+                auto converted_dim = context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(static_cast<int>(i)), element::i32));
+                list_elems.push_back(converted_dim);
             }
         }
-        // We cannot use multiple -1s if there are more than 1 dynamic dimensions
-        if (num_dyn_dims >= 2) {
-            auto inp_shape = context.get_input(0).get_partial_shape();
-            // If there are multiple dynamic dymensions, we cannot support inputs with dynamic rank
-            if (inp_shape.rank().is_static()) {
-                auto zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {0}));
-                if (inp_shape.size() >= 3 && inp_shape.size() + 1 == shape_vec.size() && shape_vec[0] == 1 &&
-                    inp_shape[0] == shape_vec[1]) {
-                    // [N, ...] -> [1, N, ...] Can be translated to Unsqueeze
-                    auto unsqueeze =
-                        context.mark_node(std::make_shared<ov::op::v0::Unsqueeze>(context.get_input(0), zero));
-                    return {unsqueeze};
-                } else if (shape_vec.size() >= 3 && shape_vec.size() + 1 == inp_shape.size() && inp_shape[0] == 1 &&
-                           inp_shape[1] == shape_vec[0]) {
-                    // [1, N, ...] -> [N, ...] Can be translated to Squeeze
-                    auto squeeze = context.mark_node(std::make_shared<ov::op::v0::Squeeze>(context.get_input(0), zero));
-                    return {squeeze};
-                } else if (inp_shape.size() == shape_vec.size()) {
-                    // If the input rank is equal to output rank, we can use 0s in place of dynamic dimensions
-                    for (size_t k = 0; k < shape_vec.size(); k++) {
-                        if (shape_vec[k] == -1)
-                            shape_vec[k] = 0;
-                    }
-                } else {
-                    FRONT_END_GENERAL_CHECK(
-                        false,
-                        "Cannot support reshape with multiple dynamic dimensions for unequal ranks");
-                }
-            } else {
-                FRONT_END_GENERAL_CHECK(
-                    false,
-                    "Cannot support reshape with multiple dynamic dimensions for dynamic input ranks");
-            }
-        }
-
-        auto shape_const = ov::op::v0::Constant::create(element::i32, Shape{num_inputs - 1}, shape_vec);
-        auto reshape = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), shape_const, true);
+        auto concat = std::make_shared<ov::op::v0::Concat>(OutputVector(list_elems.begin(), list_elems.end()), 0);
+        auto reshape = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), concat, true);
         return {context.mark_node(reshape)};
     } else {
         auto shape_input = context.get_input(1);
diff --git a/src/frontends/pytorch/src/op/slice_scatter.cpp b/src/frontends/pytorch/src/op/slice_scatter.cpp
index d522b6d63d3d81..fbe98d65ca12da 100644
--- a/src/frontends/pytorch/src/op/slice_scatter.cpp
+++ b/src/frontends/pytorch/src/op/slice_scatter.cpp
@@ -44,7 +44,7 @@ OutputVector translate_slice_scatter_fx(const NodeContext& context) {
     ov::Output<Node> end;
     if (!context.input_is_none(4)) {
         end = context.get_input(4);
-        if (end.get_partial_shape().rank().is_dynamic() || end.get_partial_shape().rank().get_length() == 0) {
+        if (!(end.get_partial_shape().rank().is_dynamic()) && end.get_partial_shape().rank().get_length() == 0) {
             end = context.mark_node(std::make_shared<v0::Unsqueeze>(end, axis_0));
         }
     } else {
@@ -65,4 +65,4 @@
 }  // namespace op
 }  // namespace pytorch
 }  // namespace frontend
-}  // namespace ov
\ No newline at end of file
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 9cc73f854bbf6a..8693b66561f036 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -758,7 +758,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten._scaled_dot_product_flash_attention_for_cpu.default", op::translate_scaled_dot_product_attention_fx},
         {"aten._softmax.default", op::translate_softmax_fx},
         {"aten._to_copy.default", op::translate_to_fx},
-        {"aten._unsafe_view.default", op::translate_reshape},
+        {"aten._unsafe_view.default", op::translate_reshape_fx},
         {"aten.abs.default", op::translate_1to1_match_1_inputs<opset10::Abs>},
         {"aten.acos.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acos>},
         {"aten.acosh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acosh>},
@@ -963,6 +963,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.topk.default", op::translate_topk_fx},
         {"aten.transpose.int", op::translate_transpose},
         {"aten.tril.default", op::translate_tril},
+        {"aten.triu.default", op::translate_triu},
         {"aten.unbind.int", op::translate_unbind_int_fx},
         {"aten.unfold.default", op::translate_unfold},
         {"aten.unsqueeze.default", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},