Skip to content

Commit

Permalink
Additional op support for ChatGLM2
Browse files Browse the repository at this point in the history
  • Loading branch information
cavusmustafa committed Feb 28, 2024
1 parent 9f4bdae commit 272e0f2
Show file tree
Hide file tree
Showing 11 changed files with 221 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ def __init__(self, options):
"torch.ops.aten.addcmul.default": None,
"torch.ops.aten.addmm.default": None,
"torch.ops.aten.alias.default": None,
"torch.ops.aten.all.default": None,
"torch.ops.aten.amax.default": None,
"torch.ops.aten.amin.default": None,
"torch.ops.aten.any.default": None,
"torch.ops.aten.arange.default": None,
"torch.ops.aten.arange.start": None,
"torch.ops.aten.arange.start_step": None,
Expand All @@ -75,6 +77,7 @@ def __init__(self, options):
"torch.ops.aten.bitwise_and.Tensor": None,
"torch.ops.aten.bitwise_not.default": None,
"torch.ops.aten.bitwise_or.Tensor": None,
"torch.ops.aten.bitwise_right_shift.Tensor": None,
"torch.ops.aten.bitwise_xor.Tensor": None,
"torch.ops.aten.bmm.default": None,
"torch.ops.aten.cat.default": None,
Expand All @@ -93,6 +96,7 @@ def __init__(self, options):
"torch.ops.aten.cosh.default": None,
"torch.ops.aten.cumsum.default": None,
"torch.ops.aten.detach.default": None,
"torch.ops.aten.detach_.default": None,
"torch.ops.aten.div.Scalar": None,
"torch.ops.aten.div.Tensor": None,
"torch.ops.aten.div.Tensor_mode": None,
Expand All @@ -108,6 +112,7 @@ def __init__(self, options):
"torch.ops.aten.fake_quantize_per_channel_affine_cachemask.default": None,
"torch.ops.aten.fill.Scalar": None,
"torch.ops.aten.fill.Tensor": None,
"torch.ops.aten.fill_.Tensor": None,
"torch.ops.aten.flip.default": None,
"torch.ops.aten.floor.default": None,
"torch.ops.aten.floor.default": None,
Expand All @@ -131,6 +136,8 @@ def __init__(self, options):
"torch.ops.aten.hardtanh_.default": None,
"torch.ops.aten.index.Tensor": None,
"torch.ops.aten.index_select.default": None,
"torch.ops.aten.isinf.default": None,
"torch.ops.aten.isnan.default": None,
"torch.ops.aten.le.Scalar": None,
"torch.ops.aten.le.Tensor": None,
"torch.ops.aten.leaky_relu.default": None,
Expand All @@ -142,10 +149,12 @@ def __init__(self, options):
"torch.ops.aten.log10.default": None,
"torch.ops.aten.log1p.default": None,
"torch.ops.aten.log2.default": None,
"torch.ops.aten.logical_not.default": None,
"torch.ops.aten.logsumexp.default": None,
"torch.ops.aten.lt.Scalar": None,
"torch.ops.aten.lt.Tensor": None,
"torch.ops.aten.masked_fill.Tensor": None,
"torch.ops.aten.masked_fill.Scalar": None,
"torch.ops.aten.masked_fill_.Scalar": None,
"torch.ops.aten.max.default": None,
"torch.ops.aten.max.dim": None,
Expand All @@ -169,6 +178,7 @@ def __init__(self, options):
"torch.ops.aten.neg.default": None,
"torch.ops.aten.new_full.default": None,
"torch.ops.aten.new_ones.default": None,
"torch.ops.aten.ones.default": None,
"torch.ops.aten.permute.default": None,
"torch.ops.aten.pow.Scalar": None,
"torch.ops.aten.pow.Tensor_Scalar": None,
Expand All @@ -182,6 +192,7 @@ def __init__(self, options):
"torch.ops.aten.rsub.Scalar": None,
"torch.ops.aten.rsub.Tensor": None,
"torch.ops.aten.scalar_tensor.default": None,
"torch.ops.aten.scatter.src": None,
"torch.ops.aten.scatter.value": None,
"torch.ops.aten.select.int": None,
"torch.ops.aten.select_scatter.default": None,
Expand All @@ -193,6 +204,7 @@ def __init__(self, options):
"torch.ops.aten.sinh.default": None,
"torch.ops.aten.slice.Tensor": None,
"torch.ops.aten.slice_scatter.default": None,
"torch.ops.aten.sort.default": None,
"torch.ops.aten.split.Tensor": None,
"torch.ops.aten.split_with_sizes.default": None,
"torch.ops.aten.sqrt.default": None,
Expand All @@ -206,7 +218,10 @@ def __init__(self, options):
"torch.ops.aten.t.default": None,
"torch.ops.aten.tan.default": None,
"torch.ops.aten.tanh.default": None,
"torch.ops.aten.topk.default": None,
"torch.ops.aten.transpose.int": None,
"torch.ops.aten.tril.default": None,
"torch.ops.aten.tril_.default": None,
"torch.ops.aten.unbind.int": None,
"torch.ops.aten.unfold.default": None,
"torch.ops.aten.unsqueeze.default": None,
Expand Down
32 changes: 32 additions & 0 deletions src/frontends/pytorch/src/op/any.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/reduce_logical_or.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/not_equal.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_any_fx(const NodeContext& context) {
    // aten.any.default(Tensor self) -> Tensor
    // Full reduction: true if any element of the input is true.
    // Implemented by flattening the input to 1-D and applying ReduceLogicalOr
    // over axis 0, yielding a scalar result (keep_dims == false).
    num_inputs_check(context, 1, 1);
    auto x = context.get_input(0);
    // NOTE(review): ReduceLogicalOr expects a boolean input; this assumes the
    // FX graph supplies a bool tensor here (the unused not_equal.hpp include
    // suggests a NotEqual-zero booleanization was intended) — confirm for
    // non-boolean inputs.
    bool keep_dims = false;
    auto const_minus_one = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {-1}));
    // Reshape to {-1}: collapse all dimensions so the reduction covers every element.
    auto flatten_source = context.mark_node(std::make_shared<ov::op::v1::Reshape>(x, const_minus_one, false));
    auto const_zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {0}));
    auto any = context.mark_node(std::make_shared<ov::op::v1::ReduceLogicalOr>(flatten_source, const_zero, keep_dims));
    return {any};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
20 changes: 20 additions & 0 deletions src/frontends/pytorch/src/op/bitwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "openvino/op/bitwise_not.hpp"
#include "openvino/op/bitwise_or.hpp"
#include "openvino/op/bitwise_xor.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/logical_not.hpp"
#include "utils.hpp"

namespace ov {
Expand Down Expand Up @@ -60,6 +62,24 @@ OutputVector translate_bitwise_xor(const NodeContext& context) {
return {xor_x};
};

OutputVector translate_bitwise_not_fx(const NodeContext& context) {
    // aten.bitwise_not(Tensor self, *, Tensor(a!) out=None) -> Tensor
    // FX variant: non-boolean inputs are first converted to boolean, then
    // negated with LogicalNot; the optional `out` input (index 1) is mutated
    // when present.
    //
    // NOTE(review): for integer tensors PyTorch's bitwise_not flips bits,
    // whereas Convert-to-bool + LogicalNot only matches the boolean case —
    // confirm the FX graphs handled here only feed bool-like inputs.
    //
    // Fix: the original duplicated the LogicalNot / mutate_input / return
    // sequence in both branches; the conversion is now applied conditionally
    // and the rest of the logic exists once.
    num_inputs_check(context, 1, 2);
    auto x = context.get_input(0);
    if (x.get_element_type() != element::boolean) {
        x = context.mark_node(std::make_shared<ov::op::v0::Convert>(x, element::boolean));
    }
    auto not_x = context.mark_node(std::make_shared<ov::op::v1::LogicalNot>(x));
    if (!context.input_is_none(1)) {
        context.mutate_input(1, not_x);
    }
    return {not_x};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
Expand Down
12 changes: 9 additions & 3 deletions src/frontends/pytorch/src/op/cat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,22 @@ OutputVector translate_stack_fx(const NodeContext& context) {
auto dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
std::deque<Output<Node>> list_elems;
auto num_elements = context.get_input_size();
if (num_elements > 2)
num_elements = num_elements - 1;
for (size_t i = 0; i < num_elements; i++) {
for (size_t i = 0; i < num_elements - 1; i++) {
auto stack_input =
context.mark_node(std::make_shared<v0::Unsqueeze>(context.get_input(static_cast<int>(i)), dim));
list_elems.push_back(stack_input);
}
int64_t axis = 0;
if (context.get_input_size() > 2)
axis = context.const_input<int64_t>(context.get_input_size() - 1);
if (!context.get_input_type(context.get_input_size() - 1).is<type::List>()) {
// axis can be not present and that means that last input will have List type
axis = context.const_input<int64_t>(context.get_input_size() - 1);
} else {
auto stack_input =
context.mark_node(std::make_shared<v0::Unsqueeze>(context.get_input(static_cast<int>(context.get_input_size() - 1)), dim));
list_elems.push_back(stack_input);
}
return translate_cat_common(context, list_elems, axis, true);
}

Expand Down
25 changes: 25 additions & 0 deletions src/frontends/pytorch/src/op/isinf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/add.hpp"

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/is_inf.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_isinf_fx(const NodeContext& context) {
    // aten.isinf.default(Tensor self) -> Tensor
    // Elementwise test for +/- infinity, mapped directly onto opset10 IsInf.
    num_inputs_check(context, 1, 1);
    const auto x = context.get_input(0);
    auto is_inf = context.mark_node(std::make_shared<ov::op::v10::IsInf>(x));
    return {is_inf};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
25 changes: 25 additions & 0 deletions src/frontends/pytorch/src/op/isnan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/add.hpp"

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/is_nan.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_isnan_fx(const NodeContext& context) {
    // aten.isnan.default(Tensor self) -> Tensor
    // Elementwise NaN test, mapped directly onto opset10 IsNaN.
    num_inputs_check(context, 1, 1);
    const auto x = context.get_input(0);
    auto is_nan = context.mark_node(std::make_shared<ov::op::v10::IsNaN>(x));
    return {is_nan};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
30 changes: 30 additions & 0 deletions src/frontends/pytorch/src/op/sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,36 @@ OutputVector translate_sort(const NodeContext& context) {
return topk->outputs();
};

OutputVector translate_sort_fx(const NodeContext& context) {
    // aten.sort.default(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
    // Implemented as a full-size TopK along `dim`: k equals the dimension
    // length, so the whole axis comes back sorted; MAX mode gives descending
    // order, MIN gives ascending.
    num_inputs_check(context, 1, 3);
    const auto input_tensor = context.get_input(0);
    bool descending = false;
    int64_t dim = -1;

    if (!context.input_is_none(1)) {
        dim = context.const_input<int64_t>(1);
    }
    // Fix: the original guarded this read with input_is_none(1) (copy-paste
    // from the `dim` branch) while reading input 2 — check input 2 instead.
    if (!context.input_is_none(2)) {
        descending = context.const_input<bool>(2);
    }

    auto mode = descending ? ov::op::TopKMode::MAX : ov::op::TopKMode::MIN;
    // k = shape[dim]: gather the dimension length from ShapeOf, then squeeze
    // it to a scalar as TopK expects.
    auto zero_axis = context.mark_node(opset11::Constant::create(element::i32, Shape{1}, {0}));
    auto dim_axis = context.mark_node(opset11::Constant::create(element::i64, Shape{1}, {dim}));
    auto shape = context.mark_node(std::make_shared<opset11::ShapeOf>(input_tensor));
    auto k_values_node = context.mark_node(std::make_shared<opset11::Gather>(shape, dim_axis, zero_axis));
    auto k_values = context.mark_node(std::make_shared<opset11::Squeeze>(k_values_node));
    auto topk = context.mark_node(std::make_shared<opset11::TopK>(input_tensor,
                                                                  k_values,
                                                                  dim,
                                                                  mode,
                                                                  ov::op::TopKSortType::SORT_VALUES,
                                                                  element::i64));
    // PyTorch expects i64 indices; the pair is returned as a ListConstruct
    // so the FX tuple output maps onto (values, indices).
    auto indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(topk->output(1), element::i64));
    return {context.mark_node(make_list_construct(OutputVector({topk->output(0), indices})))};
};

OutputVector translate_argsort(const NodeContext& context) {
auto sort = translate_sort(context);
return {sort[1]};
Expand Down
6 changes: 3 additions & 3 deletions src/frontends/pytorch/src/op/split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ OutputVector translate_chunk_fx(const NodeContext& context) {

std::shared_ptr<ov::Node> chunk;
auto dim_val = context.const_input<int>(2);
auto shape = context.get_input(0).get_shape();
auto shape = context.get_input(0).get_partial_shape();
if (dim_val < 0) {
dim_val = static_cast<int>(shape.size()) + dim_val;
dim_val = static_cast<int>(shape.rank().get_length()) + dim_val;
}
int num_splits = static_cast<int>(shape[dim_val]) / num_chunks;
int num_splits = static_cast<int>(shape[dim_val].get_length()) / num_chunks;

chunk = context.mark_node(std::make_shared<v1::Split>(context.get_input(0), dim, num_splits));

Expand Down
34 changes: 34 additions & 0 deletions src/frontends/pytorch/src/op/topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "openvino/op/topk.hpp"

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/op/convert.hpp"
#include "utils.hpp"

Expand Down Expand Up @@ -41,6 +42,39 @@ OutputVector translate_topk(const NodeContext& context) {
return {topk->output(0), indices};
};

OutputVector translate_topk_fx(const NodeContext& context) {
    // aten.topk.default(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> Tuple[Tensor, Tensor]
    // Maps directly onto opset3 TopK; `largest` selects MAX vs MIN mode and
    // `sorted` selects SORT_VALUES vs NONE. Defaults apply when the optional
    // inputs are absent.
    num_inputs_check(context, 2, 5);
    const auto input = context.get_input(0);
    const auto k = context.get_input(1);

    const int64_t axis = context.input_is_none(2) ? -1 : context.const_input<int64_t>(2);
    const bool largest = context.input_is_none(3) ? true : context.const_input<bool>(3);
    const bool is_sorted = context.input_is_none(4) ? true : context.const_input<bool>(4);

    const auto mode = largest ? TopKMode::MAX : TopKMode::MIN;
    const auto sort_type = is_sorted ? TopKSortType::SORT_VALUES : TopKSortType::NONE;

    auto topk = context.mark_node(std::make_shared<v3::TopK>(input, k, axis, mode, sort_type));
    // PyTorch expects i64 indices; wrap (values, indices) in a ListConstruct
    // so the FX tuple output maps onto the pair.
    auto indices = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));

    return {context.mark_node(make_list_construct(OutputVector({topk->output(0), indices})))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
Expand Down
7 changes: 6 additions & 1 deletion src/frontends/pytorch/src/op/var_mean.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ OutputVector translate_var(const NodeContext& context) {
return {res[0]};
}

OutputVector translate_var_correction_fx(const NodeContext& context) {
    // FX variant of aten.var.correction: delegates to translate_var_mean and
    // wraps its outputs in a ListConstruct for the FX graph.
    // NOTE(review): translate_var_mean yields both var and mean (translate_var
    // keeps only res[0]); this wraps everything it returns — confirm the FX
    // consumer expects that.
    const auto var_mean_results = translate_var_mean(context);
    return {context.mark_node(make_list_construct(var_mean_results))};
}

OutputVector translate_std(const NodeContext& context) {
auto res = translate_var_mean(context);
auto var = res[0];
Expand All @@ -160,4 +165,4 @@ OutputVector translate_std_mean(const NodeContext& context) {
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
} // namespace ov
Loading

0 comments on commit 272e0f2

Please sign in to comment.