From bf4b3467db98814bf0de65f221be63960846abc0 Mon Sep 17 00:00:00 2001
From: Cavus Mustafa
Date: Sat, 16 Sep 2023 18:41:32 -0700
Subject: [PATCH] TorchFX: Additional op support for LLMs

---
 .../pytorch/torchdynamo/op_support.py  |  3 ++
 src/frontends/pytorch/src/op/any.cpp   | 32 +++++++++++++++++
 src/frontends/pytorch/src/op/isinf.cpp | 26 ++++++++++++++
 src/frontends/pytorch/src/op/isnan.cpp | 26 ++++++++++++++
 src/frontends/pytorch/src/op/sort.cpp  | 32 +++++++++++++++++
 src/frontends/pytorch/src/op/topk.cpp  | 34 +++++++++++++++++++
 src/frontends/pytorch/src/op_table.cpp | 16 ++++++---
 7 files changed, 165 insertions(+), 4 deletions(-)
 create mode 100644 src/frontends/pytorch/src/op/any.cpp
 create mode 100644 src/frontends/pytorch/src/op/isinf.cpp
 create mode 100644 src/frontends/pytorch/src/op/isnan.cpp

diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
index 70f6a69bf9658e..e76dc0e854783f 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
@@ -38,6 +38,7 @@ def __init__(self):
             "torch.ops.aten.add_.Tensor": None,
             "torch.ops.aten.addmm.default": None,
             "torch.ops.aten.all.default": None,
+            "torch.ops.aten.any.default": None,
             "torch.ops.aten.arange.start": None,
             "torch.ops.aten.arange.start_step": None,
             "torch.ops.aten.arange.default": None,
@@ -71,6 +72,8 @@ def __init__(self):
             "torch.ops.aten.hardswish_.default": None,
             "torch.ops.aten.hardtanh_.default": None,
             "torch.ops.aten.index.Tensor": None,
+            "torch.ops.aten.isinf.default": None,
+            "torch.ops.aten.isnan.default": None,
             "torch.ops.aten.le.Scalar": None,
             "torch.ops.aten.lift_fresh_copy.default": None,
             "torch.ops.aten.linalg_vector_norm.default": None,
diff --git a/src/frontends/pytorch/src/op/any.cpp b/src/frontends/pytorch/src/op/any.cpp
new file mode 100644
index 00000000000000..22b8353715cf0c
--- /dev/null
+++ b/src/frontends/pytorch/src/op/any.cpp
@@ -0,0 +1,32 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/reduce_logical_or.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/not_equal.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+OutputVector translate_any_fx(const NodeContext& context) {
+    num_inputs_check(context, 1, 1);
+    auto x = context.get_input(0);
+    bool keep_dims = false;
+    // Flatten the input and reduce all elements with a logical OR along axis 0.
+    auto const_minus_one = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {-1}));
+    auto flatten_source = context.mark_node(std::make_shared<ov::op::v1::Reshape>(x, const_minus_one, false));
+    auto const_zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {0}));
+    auto any = context.mark_node(std::make_shared<ov::op::v1::ReduceLogicalOr>(flatten_source, const_zero, keep_dims));
+    return {any};
+};
+
+} // namespace op
+} // namespace pytorch
+} // namespace frontend
+} // namespace ov
\ No newline at end of file
diff --git a/src/frontends/pytorch/src/op/isinf.cpp b/src/frontends/pytorch/src/op/isinf.cpp
new file mode 100644
index 00000000000000..6e339d3bfa520c
--- /dev/null
+++ b/src/frontends/pytorch/src/op/isinf.cpp
@@ -0,0 +1,26 @@
+
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/op/add.hpp"
"openvino/op/add.hpp" + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/is_inf.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_isinf_fx(const NodeContext& context) { + num_inputs_check(context, 1, 1); + auto input = context.get_input(0); + return {context.mark_node(std::make_shared(input))}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op/isnan.cpp b/src/frontends/pytorch/src/op/isnan.cpp new file mode 100644 index 00000000000000..314d8a74cd1798 --- /dev/null +++ b/src/frontends/pytorch/src/op/isnan.cpp @@ -0,0 +1,26 @@ + +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/add.hpp" + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/is_nan.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_isnan_fx(const NodeContext& context) { + num_inputs_check(context, 1, 1); + auto input = context.get_input(0); + return {context.mark_node(std::make_shared(input))}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op/sort.cpp b/src/frontends/pytorch/src/op/sort.cpp index 7e75a98c6cebd8..4b2c1a8e5047f4 100644 --- a/src/frontends/pytorch/src/op/sort.cpp +++ b/src/frontends/pytorch/src/op/sort.cpp @@ -2,7 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 // #include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/util/framework_node.hpp" #include "openvino/opsets/opset11.hpp" +#include "openvino/op/convert.hpp" #include "utils.hpp" namespace ov { namespace frontend { @@ -41,6 +43,36 @@ OutputVector translate_sort(const NodeContext& context) { return topk->outputs(); }; +OutputVector translate_sort_fx(const NodeContext& context) { + // aten.sort.default(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + num_inputs_check(context, 1, 3); + const auto input_tensor = context.get_input(0); + bool descending = false; + int64_t dim = -1; + + if (!context.input_is_none(1)) { + dim = context.const_input(1); + } + if (!context.input_is_none(1)) { + descending = context.const_input(2); + } + + auto mode = descending ? 
+    auto zero_axis = context.mark_node(opset11::Constant::create(element::i32, Shape{1}, {0}));
+    auto dim_axis = context.mark_node(opset11::Constant::create(element::i64, Shape{1}, {dim}));
+    auto shape = context.mark_node(std::make_shared<opset11::ShapeOf>(input_tensor));
+    auto k_values_node = context.mark_node(std::make_shared<opset11::Gather>(shape, dim_axis, zero_axis));
+    auto k_values = context.mark_node(std::make_shared<opset11::Squeeze>(k_values_node));
+    auto topk = context.mark_node(std::make_shared<opset11::TopK>(input_tensor,
+                                                                  k_values,
+                                                                  dim,
+                                                                  mode,
+                                                                  ov::op::TopKSortType::SORT_VALUES,
+                                                                  element::i64));
+    auto indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(topk->output(1), element::i64));
+    return {context.mark_node(make_list_construct(OutputVector({topk->output(0), indices})))};
+};
+
 OutputVector translate_argsort(const NodeContext& context) {
     auto sort = translate_sort(context);
     return {sort[1]};
diff --git a/src/frontends/pytorch/src/op/topk.cpp b/src/frontends/pytorch/src/op/topk.cpp
index 06916c4ea03e2f..345cbabb86299d 100644
--- a/src/frontends/pytorch/src/op/topk.cpp
+++ b/src/frontends/pytorch/src/op/topk.cpp
@@ -5,6 +5,7 @@
 #include "openvino/op/topk.hpp"
 
 #include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/util/framework_node.hpp"
 #include "openvino/op/convert.hpp"
 #include "utils.hpp"
 
@@ -41,6 +42,39 @@ OutputVector translate_topk(const NodeContext& context) {
     return {topk->output(0), indices};
 };
 
+OutputVector translate_topk_fx(const NodeContext& context) {
+    // aten.topk.default(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> Tuple[Tensor, Tensor]
+    num_inputs_check(context, 2, 5);
+    const auto input_tensor = context.get_input(0);
+    auto k = context.get_input(1);
+    int64_t axis{-1};
+    bool largest = true;
+    bool sorted = true;
+    auto mode = TopKMode::MIN;
+    auto sort = TopKSortType::NONE;
+
+    if (!context.input_is_none(2)) {
+        axis = context.const_input<int64_t>(2);
+    }
+    if (!context.input_is_none(3)) {
+        largest = context.const_input<bool>(3);
+    }
+    if (!context.input_is_none(4)) {
+        sorted = context.const_input<bool>(4);
+    }
+    if (largest) {
+        mode = TopKMode::MAX;
+    }
+    if (sorted) {
+        sort = TopKSortType::SORT_VALUES;
+    }
+
+    auto topk = context.mark_node(std::make_shared<v3::TopK>(input_tensor, k, axis, mode, sort));
+    auto indices = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
+
+    return {context.mark_node(make_list_construct(OutputVector({topk->output(0), indices})))};
+};
+
 } // namespace op
 } // namespace pytorch
 } // namespace frontend
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index daeac6feb03c91..77463171d6a860 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -186,6 +186,7 @@ OP_CONVERTER(translate_quantized_convnd);
 OP_CONVERTER(translate_quantized_convnd_relu);
 OP_CONVERTER(translate_quantized_linear);
 // Torch FX Translations
+OP_CONVERTER(translate_any_fx);
 OP_CONVERTER(translate_arange_fx);
 OP_CONVERTER(translate_batch_norm_fx);
 OP_CONVERTER(translate_cat_fx);
@@ -193,12 +194,16 @@ OP_CONVERTER(translate_chunk_fx);
 OP_CONVERTER(translate_expand_fx);
 OP_CONVERTER(translate_group_norm_fx);
 OP_CONVERTER(translate_index_fx);
+OP_CONVERTER(translate_isinf_fx);
+OP_CONVERTER(translate_isnan_fx);
 OP_CONVERTER(translate_layer_norm_fx);
 OP_CONVERTER(translate_max_poolnd_fx);
 OP_CONVERTER(translate_slice_fx);
 OP_CONVERTER(translate_softmax_fx);
+OP_CONVERTER(translate_sort_fx);
 OP_CONVERTER(translate_split_with_sizes_fx);
 OP_CONVERTER(translate_stack_fx);
+OP_CONVERTER(translate_topk_fx);
 OP_CONVERTER(translate_transpose_fx);
 
 } // namespace op
@@ -501,6 +506,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.add_.Tensor", op::translate_add},
         {"aten.addmm.default", op::translate_addmm},
         {"aten.all.default", op::translate_all},
+        {"aten.any.default", op::translate_any_fx},
         {"aten.arange.start", op::translate_arange_fx},
         {"aten.arange.start_step", op::translate_arange_fx},
         {"aten.arange.default", op::translate_arange_fx},
@@ -525,7 +531,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.eq.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
         {"aten.exp.default", op::translate_1to1_match_1_inputs<opset10::Exp>},
         {"aten.expand.default", op::translate_expand_fx},
-        {"aten.fill_.Tensor", op::inplace_op},
+        {"aten.fill_.Tensor", op::inplace_op},
         {"aten.full.default", op::translate_full},
         {"aten.gather.default", op::translate_gather},
         {"aten.gelu.default", op::translate_gelu},
@@ -534,6 +540,8 @@
         {"aten.hardswish_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
         {"aten.hardtanh_.default", op::inplace_op<op::translate_hardtanh>},
         {"aten.index.Tensor", op::translate_index_fx},
+        {"aten.isinf.default", op::translate_isinf_fx},
+        {"aten.isnan.default", op::translate_isnan_fx},
         {"aten.le.Scalar", op::translate_1to1_match_2_inputs_align_types<opset10::LessEqual>},
         {"aten.lift_fresh_copy.default", op::skip_node},
         {"aten.linalg_vector_norm.default", op::translate_linalg_vector_norm},
@@ -565,14 +573,14 @@
         {"aten.relu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
         {"aten.rsqrt.default", op::translate_rsqrt},
         {"aten.rsub.Scalar", op::translate_rsub},
-        {"aten.scatter.src", op::translate_scatter},
+        {"aten.scatter.src", op::translate_scatter},
         {"aten.select.int", op::translate_select},
         {"aten.sigmoid.default", op::translate_1to1_match_1_inputs<opset10::Sigmoid>},
         {"aten.silu.default", op::translate_1to1_match_1_inputs<opset10::Swish>},
         {"aten.silu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Swish>>},
         {"aten.sin.default", op::translate_1to1_match_1_inputs<opset10::Sin>},
         {"aten.slice.Tensor", op::translate_slice_fx},
-        {"aten.sort.default", op::translate_sort},
+        {"aten.sort.default", op::translate_sort_fx},
         {"aten.split.Tensor", op::translate_chunk_fx},
         {"aten.split_with_sizes.default", op::translate_split_with_sizes_fx},
         {"aten.stack.default", op::translate_stack_fx},
@@ -581,7 +589,7 @@
         {"aten.sub.Tensor", op::translate_sub},
         {"aten.t.default", op::translate_t},
         {"aten.tanh.default", op::translate_1to1_match_1_inputs<opset10::Tanh>},
-        {"aten.topk.default", op::translate_topk},
+        {"aten.topk.default", op::translate_topk_fx},
         {"aten.transpose.int", op::translate_transpose},
         {"aten.tril.default", op::translate_tril},
         {"aten.unsqueeze.default", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
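
Illustrative usage (not part of the diff): a minimal sketch of how the aten ops enabled by this patch (any, isinf, isnan, sort, topk) could be exercised end to end through torch.compile. The backend-registration import path and the "openvino" backend name are assumptions based on the torchdynamo integration this patch extends; adjust to the actual package layout and OpenVINO release.

    # Hypothetical smoke test for the newly supported FX ops.
    import torch
    import openvino.frontend.pytorch.torchdynamo.backend  # assumed import that registers the "openvino" backend

    class LLMPostprocessOps(torch.nn.Module):
        def forward(self, logits):
            has_nan = torch.any(torch.isnan(logits))       # aten.isnan.default + aten.any.default
            inf_mask = torch.isinf(logits)                 # aten.isinf.default
            sorted_vals, sorted_idx = torch.sort(logits, dim=-1, descending=True)  # aten.sort.default
            top_vals, top_idx = torch.topk(logits, k=5, dim=-1)                    # aten.topk.default
            return has_nan, inf_mask, sorted_vals, sorted_idx, top_vals, top_idx

    model = LLMPostprocessOps()
    compiled = torch.compile(model, backend="openvino")  # backend name assumed
    print(compiled(torch.randn(2, 32)))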