Skip to content

Commit

Permalink
TorchFX: Additional op support for LLMs
Browse files Browse the repository at this point in the history
  • Loading branch information
cavusmustafa committed Sep 17, 2023
1 parent 5015d7a commit bf4b346
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self):
"torch.ops.aten.add_.Tensor": None,
"torch.ops.aten.addmm.default": None,
"torch.ops.aten.all.default": None,
"torch.ops.aten.any.default": None,
"torch.ops.aten.arange.start": None,
"torch.ops.aten.arange.start_step": None,
"torch.ops.aten.arange.default": None,
Expand Down Expand Up @@ -71,6 +72,8 @@ def __init__(self):
"torch.ops.aten.hardswish_.default": None,
"torch.ops.aten.hardtanh_.default": None,
"torch.ops.aten.index.Tensor": None,
"torch.ops.aten.isinf.default": None,
"torch.ops.aten.isnan.default": None,
"torch.ops.aten.le.Scalar": None,
"torch.ops.aten.lift_fresh_copy.default": None,
"torch.ops.aten.linalg_vector_norm.default": None,
Expand Down
32 changes: 32 additions & 0 deletions src/frontends/pytorch/src/op/any.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/reduce_logical_or.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/not_equal.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_any_fx(const NodeContext& context) {
    // aten.any.default(Tensor self) -> Tensor
    // Full reduction: flatten the input to 1-D, then ReduceLogicalOr along axis 0,
    // producing a scalar-like result (keep_dims=false), matching torch.any(x).
    num_inputs_check(context, 1, 1);
    auto x = context.get_input(0);
    // NOTE(review): ReduceLogicalOr expects a boolean input; the not_equal.hpp
    // include suggests a NotEqual-against-zero conversion was intended for
    // non-boolean tensors but was never wired in — confirm inputs are boolean.
    bool keep_dims = false;
    auto const_minus_one = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {-1}));
    // special_zero=false: -1 means "infer the full flattened extent".
    auto flatten_source = context.mark_node(std::make_shared<ov::op::v1::Reshape>(x, const_minus_one, false));
    auto const_zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {0}));
    auto any = context.mark_node(std::make_shared<ov::op::v1::ReduceLogicalOr>(flatten_source, const_zero, keep_dims));
    return {any};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
26 changes: 26 additions & 0 deletions src/frontends/pytorch/src/op/isinf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/add.hpp"

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/is_inf.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_isinf_fx(const NodeContext& context) {
    // aten.isinf.default(Tensor self) -> Tensor
    // Element-wise infinity test; maps 1:1 onto OpenVINO's IsInf-10 op.
    num_inputs_check(context, 1, 1);
    const auto x = context.get_input(0);
    auto result = context.mark_node(std::make_shared<ov::op::v10::IsInf>(x));
    return {result};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
26 changes: 26 additions & 0 deletions src/frontends/pytorch/src/op/isnan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/add.hpp"

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/is_nan.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_isnan_fx(const NodeContext& context) {
    // aten.isnan.default(Tensor self) -> Tensor
    // Element-wise NaN test; maps 1:1 onto OpenVINO's IsNaN-10 op.
    num_inputs_check(context, 1, 1);
    const auto x = context.get_input(0);
    auto result = context.mark_node(std::make_shared<ov::op::v10::IsNaN>(x));
    return {result};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
32 changes: 32 additions & 0 deletions src/frontends/pytorch/src/op/sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/opsets/opset11.hpp"
#include "openvino/op/convert.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
Expand Down Expand Up @@ -41,6 +43,36 @@ OutputVector translate_sort(const NodeContext& context) {
return topk->outputs();
};

OutputVector translate_sort_fx(const NodeContext& context) {
    // aten.sort.default(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
    // Implemented as a full-size TopK along `dim`: k equals the dimension extent,
    // so the outputs are the completely sorted values and their indices.
    num_inputs_check(context, 1, 3);
    const auto input_tensor = context.get_input(0);
    bool descending = false;
    int64_t dim = -1;

    if (!context.input_is_none(1)) {
        dim = context.const_input<int64_t>(1);
    }
    // BUGFIX: was input_is_none(1) — a copy-paste of the guard above — which
    // read const_input<bool>(2) whenever `dim` was present even if `descending`
    // was absent. The guard must check input 2, the input actually read.
    if (!context.input_is_none(2)) {
        descending = context.const_input<bool>(2);
    }

    auto mode = descending ? ov::op::TopKMode::MAX : ov::op::TopKMode::MIN;
    // k = shape[dim], gathered dynamically so dynamic shapes still work.
    auto zero_axis = context.mark_node(opset11::Constant::create(element::i32, Shape{1}, {0}));
    auto dim_axis = context.mark_node(opset11::Constant::create(element::i64, Shape{1}, {dim}));
    auto shape = context.mark_node(std::make_shared<opset11::ShapeOf>(input_tensor));
    auto k_values_node = context.mark_node(std::make_shared<opset11::Gather>(shape, dim_axis, zero_axis));
    auto k_values = context.mark_node(std::make_shared<opset11::Squeeze>(k_values_node));
    auto topk = context.mark_node(std::make_shared<opset11::TopK>(input_tensor,
                                                                  k_values,
                                                                  dim,
                                                                  mode,
                                                                  ov::op::TopKSortType::SORT_VALUES,
                                                                  element::i64));
    // PyTorch expects int64 indices; Convert guards against a narrower index type.
    auto indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(topk->output(1), element::i64));
    return {context.mark_node(make_list_construct(OutputVector({topk->output(0), indices})))};
};

OutputVector translate_argsort(const NodeContext& context) {
auto sort = translate_sort(context);
return {sort[1]};
Expand Down
34 changes: 34 additions & 0 deletions src/frontends/pytorch/src/op/topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "openvino/op/topk.hpp"

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/op/convert.hpp"
#include "utils.hpp"

Expand Down Expand Up @@ -41,6 +42,39 @@ OutputVector translate_topk(const NodeContext& context) {
return {topk->output(0), indices};
};

OutputVector translate_topk_fx(const NodeContext& context) {
    // aten.topk.default(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> Tuple[Tensor, Tensor]
    // Direct mapping onto TopK-3; optional inputs fall back to the schema defaults.
    num_inputs_check(context, 2, 5);
    const auto data = context.get_input(0);
    auto k = context.get_input(1);

    const int64_t axis = context.input_is_none(2) ? -1 : context.const_input<int64_t>(2);
    const bool largest = context.input_is_none(3) ? true : context.const_input<bool>(3);
    const bool sorted = context.input_is_none(4) ? true : context.const_input<bool>(4);

    const auto mode = largest ? TopKMode::MAX : TopKMode::MIN;
    const auto sort_type = sorted ? TopKSortType::SORT_VALUES : TopKSortType::NONE;

    auto topk = context.mark_node(std::make_shared<v3::TopK>(data, k, axis, mode, sort_type));
    // PyTorch expects int64 indices.
    auto indices = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));

    return {context.mark_node(make_list_construct(OutputVector({topk->output(0), indices})))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
Expand Down
16 changes: 12 additions & 4 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,19 +186,24 @@ OP_CONVERTER(translate_quantized_convnd);
OP_CONVERTER(translate_quantized_convnd_relu);
OP_CONVERTER(translate_quantized_linear);
// Torch FX Translations
OP_CONVERTER(translate_any_fx);
OP_CONVERTER(translate_arange_fx);
OP_CONVERTER(translate_batch_norm_fx);
OP_CONVERTER(translate_cat_fx);
OP_CONVERTER(translate_chunk_fx);
OP_CONVERTER(translate_expand_fx);
OP_CONVERTER(translate_group_norm_fx);
OP_CONVERTER(translate_index_fx);
OP_CONVERTER(translate_isinf_fx);
OP_CONVERTER(translate_isnan_fx);
OP_CONVERTER(translate_layer_norm_fx);
OP_CONVERTER(translate_max_poolnd_fx);
OP_CONVERTER(translate_slice_fx);
OP_CONVERTER(translate_softmax_fx);
OP_CONVERTER(translate_sort_fx);
OP_CONVERTER(translate_split_with_sizes_fx);
OP_CONVERTER(translate_stack_fx);
OP_CONVERTER(translate_topk_fx);
OP_CONVERTER(translate_transpose_fx);

} // namespace op
Expand Down Expand Up @@ -501,6 +506,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.add_.Tensor", op::translate_add},
{"aten.addmm.default", op::translate_addmm},
{"aten.all.default", op::translate_all},
{"aten.any.default", op::translate_any_fx},
{"aten.arange.start", op::translate_arange_fx},
{"aten.arange.start_step", op::translate_arange_fx},
{"aten.arange.default", op::translate_arange_fx},
Expand All @@ -525,7 +531,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.eq.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
{"aten.exp.default", op::translate_1to1_match_1_inputs<opset10::Exp>},
{"aten.expand.default", op::translate_expand_fx},
{"aten.fill_.Tensor", op::inplace_op<op::translate_fill_>},
{"aten.fill_.Tensor", op::inplace_op<op::translate_fill_>},
{"aten.full.default", op::translate_full},
{"aten.gather.default", op::translate_gather},
{"aten.gelu.default", op::translate_gelu},
Expand All @@ -534,6 +540,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.hardswish_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
{"aten.hardtanh_.default", op::inplace_op<op::translate_hardtanh>},
{"aten.index.Tensor", op::translate_index_fx},
{"aten.isinf.default", op::translate_isinf_fx},
{"aten.isnan.default", op::translate_isnan_fx},
{"aten.le.Scalar", op::translate_1to1_match_2_inputs_align_types<opset10::LessEqual>},
{"aten.lift_fresh_copy.default", op::skip_node},
{"aten.linalg_vector_norm.default", op::translate_linalg_vector_norm},
Expand Down Expand Up @@ -565,14 +573,14 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.relu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
{"aten.rsqrt.default", op::translate_rsqrt},
{"aten.rsub.Scalar", op::translate_rsub},
{"aten.scatter.src", op::translate_scatter},
{"aten.scatter.src", op::translate_scatter},
{"aten.select.int", op::translate_select},
{"aten.sigmoid.default", op::translate_1to1_match_1_inputs<opset10::Sigmoid>},
{"aten.silu.default", op::translate_1to1_match_1_inputs<opset10::Swish>},
{"aten.silu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Swish>>},
{"aten.sin.default", op::translate_1to1_match_1_inputs<opset10::Sin>},
{"aten.slice.Tensor", op::translate_slice_fx},
{"aten.sort.default", op::translate_sort},
{"aten.sort.default", op::translate_sort_fx},
{"aten.split.Tensor", op::translate_chunk_fx},
{"aten.split_with_sizes.default", op::translate_split_with_sizes_fx},
{"aten.stack.default", op::translate_stack_fx},
Expand All @@ -581,7 +589,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.sub.Tensor", op::translate_sub},
{"aten.t.default", op::translate_t},
{"aten.tanh.default", op::translate_1to1_match_1_inputs<opset10::Tanh>},
{"aten.topk.default", op::translate_topk},
{"aten.topk.default", op::translate_topk_fx},
{"aten.transpose.int", op::translate_transpose},
{"aten.tril.default", op::translate_tril},
{"aten.unsqueeze.default", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
Expand Down

0 comments on commit bf4b346

Please sign in to comment.