Dynamic shape support for llama3
ynimmaga committed Jun 12, 2024
1 parent e6fe146 commit 1d545b2
Showing 7 changed files with 62 additions and 61 deletions.
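For context: the FX conversion paths patched below run when a model is compiled through OpenVINO's torch.compile backend with dynamic shapes enabled. A minimal sketch of how they get exercised (the model and sizes here are illustrative, not part of this commit):

    import torch
    import openvino.torch  # registers the "openvino" backend for torch.compile

    # Any decoder-style model whose sequence length varies between calls will
    # hit the dynamic-shape handling patched in the files below.
    model = torch.nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
    compiled = torch.compile(model, backend="openvino", dynamic=True)

    # With dynamic=True, Dynamo traces sizes as SymInts instead of recompiling
    # per shape, so size arguments reach the FX graph as traced values.
    out_short = compiled(torch.randn(1, 8, 64))
    out_long = compiled(torch.randn(1, 32, 64))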
@@ -241,6 +241,7 @@ def __init__(self, options):
             "torch.ops.aten.transpose.int": None,
             "torch.ops.aten.tril.default": None,
             "torch.ops.aten.tril_.default": None,
+            "torch.ops.aten.triu.default": None,
             "torch.ops.aten.unbind.int": None,
             "torch.ops.aten.unfold.default": None,
             "torch.ops.aten.unsqueeze.default": None,
4 changes: 4 additions & 0 deletions src/frontends/pytorch/src/op/arange.cpp
@@ -7,6 +7,7 @@
 #include "openvino/op/convert.hpp"
 #include "openvino/op/convert_like.hpp"
 #include "openvino/op/range.hpp"
+#include "openvino/op/squeeze.hpp"
 #include "pt_framework_node.hpp"
 #include "utils.hpp"

@@ -108,6 +109,9 @@ OutputVector translate_arange_fx(const NodeContext& context) {
     if (context.has_attribute("dtype")) {
         dtype = context.get_attribute<element::Type>("dtype");
     }
+    if (end.get_partial_shape().rank().is_dynamic()) {
+        end = context.mark_node(std::make_shared<ov::op::v0::Squeeze>(end, zero));
+    }
     auto range = context.mark_node(std::make_shared<v4::Range>(start, end, step, dtype));
     if (!context.has_attribute("dtype")) {
         range = context.mark_node(std::make_shared<v1::ConvertLike>(range, context.get_input(0)));
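The new branch covers an `end` that arrives as a traced tensor of dynamic rank rather than a compile-time scalar; Range expects scalar inputs, so such an `end` is squeezed first. An illustration of what produces a non-constant `end` (names are illustrative, not from this commit):

    import torch

    def position_ids(x: torch.Tensor) -> torch.Tensor:
        # Under torch.compile(dynamic=True), x.shape[1] is a SymInt, so
        # aten.arange receives a traced value rather than a Python int --
        # the case the new Squeeze guard normalizes before building Range.
        return torch.arange(x.shape[1])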
33 changes: 21 additions & 12 deletions src/frontends/pytorch/src/op/expand.cpp
@@ -5,6 +5,7 @@
 #include "openvino/frontend/pytorch/node_context.hpp"
 #include "openvino/op/abs.hpp"
 #include "openvino/op/broadcast.hpp"
+#include "openvino/op/concat.hpp"
 #include "openvino/op/shape_of.hpp"
 #include "utils.hpp"

@@ -46,23 +47,31 @@ OutputVector translate_expand_fx(const NodeContext& context) {
     num_inputs_check(context, 2, num_inputs);
     auto x = context.get_input(0);
     std::vector<int32_t> shape_vec;
-    auto sizes = context.get_input(1);
-    if (num_inputs != 2) {
+    if (context.get_input_type(1).is<type::List>()) {
+        std::deque<Output<Node>> list_elems;
         for (size_t i = 1; i < num_inputs; i++) {
-            auto a = context.get_input_from_visible_context(i).get_node_shared_ptr();
-            auto shape_input = context.get_input(static_cast<int>(i));
-            if (std::dynamic_pointer_cast<ov::op::v0::Parameter>(a) ||
-                shape_input.get_partial_shape().rank().is_dynamic() ||
-                shape_input.get_partial_shape().rank().get_length() == 0) {
-                shape_vec.push_back(-1);
+            if (context.get_input_type(i).as<type::List>().element_type.is<type::PyScalar>()) {
+                auto const_val = context.const_input<int32_t>(i);
+                std::vector<int32_t> dim_vec;
+                dim_vec.push_back(const_val);
+                auto dim_const = ov::op::v0::Constant::create(element::i32, Shape{1}, dim_vec);
+                list_elems.push_back(dim_const);
             } else {
-                auto val = context.const_input<int32_t>(i);
-                shape_vec.push_back(val);
+                auto converted_dim = context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(static_cast<int>(i)), element::i32));
+                list_elems.push_back(converted_dim);
             }
         }
-        sizes = ov::op::v0::Constant::create(element::i32, Shape{num_inputs - 1}, shape_vec);
+        auto concat = std::make_shared<ov::op::v0::Concat>(OutputVector(list_elems.begin(), list_elems.end()), 0);
+        return base_expand(context, x, concat);
+    } else {
+        auto x = context.get_input(0);
+        auto sizes = context.get_input(1);
+        // TODO: figure out what implicit means
+        PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(2) || context.const_input<bool>(2) == false,
+                                    "Unexpected value of implicit for expand operation");
+        return base_expand(context, x, sizes);
     }
-    return base_expand(context, x, sizes);

 };

 } // namespace op
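The rewrite stops folding the target shape into a static Constant with -1 placeholders; instead, each size becomes a 1-element i32 tensor (a constant, or a converted traced value) and the pieces are concatenated into a runtime shape for the broadcast. A sketch of a call that takes the new path (illustrative, not from the commit):

    import torch

    def broadcast_mask(mask: torch.Tensor, batch: int) -> torch.Tensor:
        # With a symbolic batch size, aten.expand.default receives a size list
        # mixing a traced value and constants; the converter now concatenates
        # per-dimension scalars instead of emitting a Constant of -1s.
        return mask.expand(batch, -1, -1)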
26 changes: 23 additions & 3 deletions src/frontends/pytorch/src/op/full.cpp
@@ -8,6 +8,7 @@
 #include "openvino/op/constant.hpp"
 #include "openvino/op/convert_like.hpp"
 #include "openvino/op/divide.hpp"
+#include "openvino/op/concat.hpp"
 #include "openvino/op/gather.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/power.hpp"
@@ -74,9 +75,28 @@ OutputVector translate_full(const NodeContext& context) {
 OutputVector translate_full_fx(const NodeContext& context) {
     // aten.full.default([16, 16], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'),
     // pin_memory = False)
-    num_inputs_check(context, 2, 2);
-    auto sizes = context.get_input(0);
-    auto value = context.get_input(1);
+    auto num_inputs = context.get_input_size();
+    num_inputs_check(context, 2, num_inputs);
+    ov::Output<ov::Node> sizes;
+    if (context.get_input_type(0).is<type::List>()) {
+        std::deque<Output<Node>> list_elems;
+        for (size_t i = 0; i < num_inputs-1; i++) {
+            if (context.get_input_type(i).as<type::List>().element_type.is<type::PyScalar>()) {
+                auto const_val = context.const_input<int32_t>(i);
+                std::vector<int32_t> dim_vec;
+                dim_vec.push_back(const_val);
+                auto dim_const = ov::op::v0::Constant::create(element::i32, Shape{1}, dim_vec);
+                list_elems.push_back(dim_const);
+            } else {
+                auto converted_dim = context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(static_cast<int>(i)), element::i32));
+                list_elems.push_back(converted_dim);
+            }
+        }
+        sizes = std::make_shared<ov::op::v0::Concat>(OutputVector(list_elems.begin(), list_elems.end()), 0);
+    } else {
+        sizes = context.get_input(0);
+    }
+    auto value = context.get_input(num_inputs-1);

     auto filled_tensor = base_translate_full(context, sizes, value);
     if (context.has_attribute("dtype")) {
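translate_full_fx now accepts a variable number of inputs: when the size argument is a list, each element is either constant-folded or converted at runtime (mirroring the expand.cpp change), and the fill value is taken from the last input. An illustrative trigger (not from the commit):

    import torch

    def attention_bias(q_len: int, kv_len: int) -> torch.Tensor:
        # With q_len/kv_len symbolic under torch.compile(dynamic=True),
        # aten.full.default gets a size list of traced values, which the new
        # list branch assembles into a runtime shape via Concat.
        return torch.full((q_len, kv_len), float("-inf"))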
52 changes: 9 additions & 43 deletions src/frontends/pytorch/src/op/reshape.cpp
@@ -31,55 +31,21 @@ OutputVector translate_reshape_fx(const NodeContext& context) {
     num_inputs_check(context, 2, num_inputs);
     std::vector<int32_t> shape_vec;
     if (context.get_input_type(1).is<type::List>()) {
-        int num_dyn_dims = 0;
+        std::deque<Output<Node>> list_elems;
         for (size_t i = 1; i < num_inputs; i++) {
-            auto shape_input = context.get_input(static_cast<int>(i));
             if (context.get_input_type(i).as<type::List>().element_type.is<type::PyScalar>()) {
                 auto const_val = context.const_input<int32_t>(i);
-                shape_vec.push_back(const_val);
+                std::vector<int32_t> dim_vec;
+                dim_vec.push_back(const_val);
+                auto dim_const = ov::op::v0::Constant::create(element::i32, Shape{1}, dim_vec);
+                list_elems.push_back(dim_const);
             } else {
-                // Set dimension to be dynamic if it's coming from an argument or another node
-                shape_vec.push_back(-1);
-                num_dyn_dims++;
+                auto converted_dim = context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(static_cast<int>(i)), element::i32));
+                list_elems.push_back(converted_dim);
             }
         }
-        // We cannot use multiple -1s if there are more than 1 dynamic dimensions
-        if (num_dyn_dims >= 2) {
-            auto inp_shape = context.get_input(0).get_partial_shape();
-            // If there are multiple dynamic dymensions, we cannot support inputs with dynamic rank
-            if (inp_shape.rank().is_static()) {
-                auto zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {0}));
-                if (inp_shape.size() >= 3 && inp_shape.size() + 1 == shape_vec.size() && shape_vec[0] == 1 &&
-                    inp_shape[0] == shape_vec[1]) {
-                    // [N, ...] -> [1, N, ...] Can be translated to Unsqueeze
-                    auto unsqueeze =
-                        context.mark_node(std::make_shared<ov::op::v0::Unsqueeze>(context.get_input(0), zero));
-                    return {unsqueeze};
-                } else if (shape_vec.size() >= 3 && shape_vec.size() + 1 == inp_shape.size() && inp_shape[0] == 1 &&
-                           inp_shape[1] == shape_vec[0]) {
-                    // [1, N, ...] -> [N, ...] Can be translated to Squeeze
-                    auto squeeze = context.mark_node(std::make_shared<ov::op::v0::Squeeze>(context.get_input(0), zero));
-                    return {squeeze};
-                } else if (inp_shape.size() == shape_vec.size()) {
-                    // If the input rank is equal to output rank, we can use 0s in place of dynamic dimensions
-                    for (size_t k = 0; k < shape_vec.size(); k++) {
-                        if (shape_vec[k] == -1)
-                            shape_vec[k] = 0;
-                    }
-                } else {
-                    FRONT_END_GENERAL_CHECK(
-                        false,
-                        "Cannot support reshape with multiple dynamic dimensions for unequal ranks");
-                }
-            } else {
-                FRONT_END_GENERAL_CHECK(
-                    false,
-                    "Cannot support reshape with multiple dynamic dimensions for dynamic input ranks");
-            }
-        }
-
-        auto shape_const = ov::op::v0::Constant::create(element::i32, Shape{num_inputs - 1}, shape_vec);
-        auto reshape = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), shape_const, true);
+        auto concat = std::make_shared<ov::op::v0::Concat>(OutputVector(list_elems.begin(), list_elems.end()), 0);
+        auto reshape = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), concat, true);
         return {context.mark_node(reshape)};
     } else {
         auto shape_input = context.get_input(1);
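The old path emulated dynamic sizes with -1 (or 0) placeholders in a constant shape and rejected reshapes with two or more dynamic target dimensions; the new path feeds Reshape a shape assembled at runtime, so the Squeeze/Unsqueeze special cases and the rank checks become unnecessary. A typical reshape this unlocks (illustrative):

    import torch

    def merge_heads(x: torch.Tensor) -> torch.Tensor:
        # x: (batch, heads, seq, head_dim) with batch and seq both symbolic --
        # two dynamic target dims, which the old converter rejected.
        b, h, s, d = x.shape
        return x.transpose(1, 2).reshape(b, s, h * d)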
4 changes: 2 additions & 2 deletions src/frontends/pytorch/src/op/slice_scatter.cpp
@@ -44,7 +44,7 @@ OutputVector translate_slice_scatter_fx(const NodeContext& context) {
     ov::Output<ov::Node> end;
     if (!context.input_is_none(4)) {
         end = context.get_input(4);
-        if (end.get_partial_shape().rank().is_dynamic() || end.get_partial_shape().rank().get_length() == 0) {
+        if (!(end.get_partial_shape().rank().is_dynamic()) && end.get_partial_shape().rank().get_length() == 0) {
             end = context.mark_node(std::make_shared<v0::Unsqueeze>(end, axis_0));
         }
     } else {
@@ -65,4 +65,4 @@ OutputVector translate_slice_scatter_fx(const NodeContext& context) {
 } // namespace op
 } // namespace pytorch
 } // namespace frontend
-} // namespace ov
\ No newline at end of file
+} // namespace ov
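The corrected guard unsqueezes `end` only when its rank is statically known to be 0; previously a dynamic-rank `end` was also unsqueezed, which is wrong when it is already 1-D at runtime. An example of where a traced `end` shows up (illustrative, not from the commit):

    import torch

    def cache_update(cache: torch.Tensor, new: torch.Tensor, pos: int) -> torch.Tensor:
        # aten.slice_scatter.default with a traced end index: under dynamic
        # shapes, pos + new.shape[2] is a SymInt rather than a Python int.
        return cache.slice_scatter(new, dim=2, start=pos, end=pos + new.shape[2])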
3 changes: 2 additions & 1 deletion src/frontends/pytorch/src/op_table.cpp
@@ -758,7 +758,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
     {"aten._scaled_dot_product_flash_attention_for_cpu.default", op::translate_scaled_dot_product_attention_fx},
     {"aten._softmax.default", op::translate_softmax_fx},
     {"aten._to_copy.default", op::translate_to_fx},
-    {"aten._unsafe_view.default", op::translate_reshape},
+    {"aten._unsafe_view.default", op::translate_reshape_fx},
     {"aten.abs.default", op::translate_1to1_match_1_inputs<opset10::Abs>},
     {"aten.acos.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acos>},
     {"aten.acosh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acosh>},
@@ -963,6 +963,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
     {"aten.topk.default", op::translate_topk_fx},
     {"aten.transpose.int", op::translate_transpose},
     {"aten.tril.default", op::translate_tril},
+    {"aten.triu.default", op::translate_triu},
     {"aten.unbind.int", op::translate_unbind_int_fx},
     {"aten.unfold.default", op::translate_unfold},
     {"aten.unsqueeze.default", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
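aten.triu.default is registered here (and in the Python op-support table above) because causal attention masks are built with triu, and under dynamic shapes the mask size is symbolic. Illustrative usage:

    import torch

    def causal_mask(seq_len: int) -> torch.Tensor:
        # triu keeps the upper triangle above the main diagonal -- the masked
        # (future) positions in a llama-style attention mask.
        return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)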
