Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch Compile - New Op Support #23310

Merged
Merged
Show file tree
Hide file tree
Changes from 76 commits
Commits
Show all changes
78 commits
Select commit Hold shift + click to select a range
18212e1
New ops added to torch.compile support list
cavusmustafa Feb 28, 2024
b8c3575
Additional ops for NetVLad and ALIKE
cavusmustafa Feb 28, 2024
73b25f7
Additional op support for ChatGLM2
cavusmustafa Feb 28, 2024
1d1093b
PTFE input_model constructor output access fix
cavusmustafa Feb 28, 2024
25990fd
Removed bitwise_not fx version
cavusmustafa Feb 29, 2024
4b1f4f1
Additional op support for TorchFX
cavusmustafa Mar 1, 2024
eb8c2a0
Stack translation fix for TorchFX
cavusmustafa Mar 1, 2024
d9f792d
TorchFX unit tests for Div and Elu
cavusmustafa Mar 5, 2024
8365227
TorchFX unit test update: full, comparison, glu
cavusmustafa Mar 5, 2024
f4ec62e
TorchFX unit tests: grip_sample_2d, hardswish, hardtanh
cavusmustafa Mar 5, 2024
27f8e66
TorchFX unit tests: index_select, unary_ops, isinf, isnan
cavusmustafa Mar 5, 2024
da180f6
TorchFX: Additional op unit tests
cavusmustafa Mar 6, 2024
d6d8121
TorchFX: Additonal unit tests
cavusmustafa Mar 6, 2024
894eecf
Code style fix src/frontends/pytorch/src/op/cat.cpp
cavusmustafa Mar 6, 2024
2f2f7b9
Code style fix src/frontends/pytorch/src/op/any.cpp
cavusmustafa Mar 6, 2024
f1b546c
Code style fix src/frontends/pytorch/src/op/any.cpp
cavusmustafa Mar 6, 2024
04f6423
Code style fix src/frontends/pytorch/src/op/isinf.cpp
cavusmustafa Mar 6, 2024
551f7d7
Code style fix src/frontends/pytorch/src/op/isinf.cpp
cavusmustafa Mar 6, 2024
2d64021
Code style fix src/frontends/pytorch/src/op/isnan.cpp
cavusmustafa Mar 6, 2024
56d7908
Code style fix src/frontends/pytorch/src/op/isnan.cpp
cavusmustafa Mar 6, 2024
c1ab831
Code style fix src/frontends/pytorch/src/op/topk.cpp
cavusmustafa Mar 6, 2024
cebbca1
Added embedding_bag and fixed unbind int
ynimmaga Mar 6, 2024
ebd8c80
Code style fix src/frontends/pytorch/src/op/split.cpp
cavusmustafa Mar 6, 2024
207c25d
TorchFX: Unit test enabled for embedding_bag
cavusmustafa Mar 7, 2024
19fc3c1
TorchFX: bitwise_right_shift is temporarily removed
cavusmustafa Mar 7, 2024
816b2b8
Removed unnecessary include lines from bitwise translation
cavusmustafa Mar 7, 2024
98ce772
IsNan and IsInf translations are converted to 1to1
cavusmustafa Mar 7, 2024
12bb8d9
Removed prim ops and unused torchvision ops
cavusmustafa Mar 7, 2024
c23809a
import placement update for select_scatter unit test
cavusmustafa Mar 7, 2024
53490df
Removed test_log.py file
cavusmustafa Mar 7, 2024
c7fbcb7
Bugfix src/frontends/pytorch/src/op/sort.cpp
cavusmustafa Mar 8, 2024
1868d7e
var_correction_fx translation removed
cavusmustafa Mar 8, 2024
af3f54a
TorchFX: op list sorting
cavusmustafa Mar 8, 2024
e912a2f
Revert change for tests/layer_tests/pytorch_tests/test_batch_norm.py
cavusmustafa Mar 8, 2024
82a4389
Revert the change for tests/layer_tests/pytorch_tests/test_flip.py
cavusmustafa Mar 8, 2024
da637e5
Fixed mistakenly removed line in test_masked_fill.py
cavusmustafa Mar 8, 2024
e240e33
Typo fix in test_torch_decoder.py
cavusmustafa Mar 8, 2024
d188fba
Common sort translation for FX and TS
cavusmustafa Mar 8, 2024
6f73d19
hardswish unit test moved into unary_ops
cavusmustafa Mar 8, 2024
5bcd2c3
new line fix in var_mean.cpp
cavusmustafa Mar 8, 2024
ea90465
TorchFX: IsFinite op support
cavusmustafa Mar 8, 2024
3bc2b47
aten.hardtanh unit test support for multiple data types and input shapes
cavusmustafa Mar 8, 2024
4b7ce29
TorchFX: aten.any translation uptade
cavusmustafa Mar 11, 2024
2f025a5
Merge branch 'master' into torch_compile/new_op_support_3
cavusmustafa Mar 11, 2024
c1544e1
Simplified aten.any translation
cavusmustafa Mar 11, 2024
6e49c99
TorchFX: Support for aten.any.dim
cavusmustafa Mar 12, 2024
3f33267
TorchFX: Removed converting to i64 in sort translation
cavusmustafa Mar 12, 2024
6036867
Sort translation formatting fix
cavusmustafa Mar 12, 2024
6e344e5
New unit tests only enabled for FX backend
cavusmustafa Mar 13, 2024
8652d61
TorchFX: split unit test temporarily disabled
cavusmustafa Mar 13, 2024
030aba4
scale support for elu translation & typo fix
cavusmustafa Mar 14, 2024
5aaf117
Removed bitwise_right_shift for FX
cavusmustafa Mar 14, 2024
c59c716
Merge branch 'master' into torch_compile/new_op_support_3
cavusmustafa Mar 14, 2024
51deae6
Removed unused include lines in elu
cavusmustafa Mar 14, 2024
cb68bbe
Code formatting fix
cavusmustafa Mar 14, 2024
2e1becf
TorchFX: Log sigmoid fix
cavusmustafa Mar 18, 2024
5177557
Update src/frontends/pytorch/src/op/any.cpp
cavusmustafa Mar 19, 2024
ae27838
Update tests/layer_tests/pytorch_tests/test_unary_ops.py
cavusmustafa Mar 19, 2024
e4d1f0e
Merge branch 'master' into torch_compile/new_op_support_3
cavusmustafa Mar 19, 2024
de5e913
Any translation fix
cavusmustafa Mar 19, 2024
4f28e19
Enable TorchFX unit test for new_ones
cavusmustafa Mar 19, 2024
d7c50db
argmax_argmin translation using stable=true
cavusmustafa Mar 19, 2024
d8fcf9e
Merge branch 'master' into torch_compile/new_op_support_3
cavusmustafa Mar 19, 2024
bffe432
Code formatting fix
cavusmustafa Mar 19, 2024
4c25f5a
Code formatting fix
cavusmustafa Mar 19, 2024
25a5380
Code formatting fix
cavusmustafa Mar 19, 2024
611b61e
Code formatting fix
cavusmustafa Mar 19, 2024
aab368c
Revert removing convert from argmax_argmin translation
cavusmustafa Mar 19, 2024
e06489c
Code formatting fix
cavusmustafa Mar 19, 2024
01e1d85
Code formatting fix
cavusmustafa Mar 19, 2024
0b47a86
argmax_argmin topk fix
cavusmustafa Mar 20, 2024
adaad2d
Merge branch 'master' into torch_compile/new_op_support_3
cavusmustafa Mar 20, 2024
86bf932
TorchFX: revert changes to use new new_ones, new_full, and ones trans…
cavusmustafa Mar 20, 2024
f10afed
Merge branch 'master' into torch_compile/new_op_support_3
cavusmustafa Mar 20, 2024
20d2afa
Added div_ and rand ops
suryasidd Mar 20, 2024
293493c
log_sigmoid fix for FX and TS
cavusmustafa Mar 20, 2024
745a451
Added missing div_ to op_table
suryasidd Mar 20, 2024
92af4e2
Merge branch 'master' into torch_compile/new_op_support_3
cavusmustafa Mar 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,104 +29,217 @@ class OperatorSupport(OperatorSupport):
def __init__(self, options):
support_dict = {
"_operator.getitem": None,
"torch.ops.aten._adaptive_avg_pool1d.default": None,
"torch.ops.aten._adaptive_avg_pool2d.default": None,
"torch.ops.aten._adaptive_avg_pool3d.default": None,
"torch.ops.aten._convolution.default": None,
"torch.ops.aten._embedding_bag.default": None,
"torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default": None,
"torch.ops.aten._local_scalar_dense.default": None,
"torch.ops.aten._log_softmax.default": None,
"torch.ops.aten._native_batch_norm_legit.default": None,
"torch.ops.aten._native_batch_norm_legit.no_stats": None,
"torch.ops.aten._native_batch_norm_legit_functional.default": None,
"torch.ops.aten._native_batch_norm_legit_no_training.default": None,
"torch.ops.aten._scaled_dot_product_flash_attention.default": None,
"torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default": None,
"torch.ops.aten._softmax.default": None,
"torch.ops.aten._to_copy.default": None,
"torch.ops.aten._unsafe_view.default": None,
"torch.ops.aten._unsafe_view.default": None,
"torch.ops.aten.abs.default": None,
"torch.ops.aten.acos.default": None,
"torch.ops.aten.acosh.default": None,
"torch.ops.aten.adaptive_max_pool1d.default": None,
"torch.ops.aten.adaptive_max_pool2d.default": None,
"torch.ops.aten.adaptive_max_pool3d.default": None,
"torch.ops.aten.add.Scalar": None,
"torch.ops.aten.add.Tensor": None,
"torch.ops.aten.add_.Tensor": None,
"torch.ops.aten.addcmul.default": None,
"torch.ops.aten.addmm.default": None,
"torch.ops.aten.alias.default": None,
"torch.ops.aten.all.default": None,
"torch.ops.aten.amax.default": None,
"torch.ops.aten.arange.start": None,
"torch.ops.aten.amin.default": None,
"torch.ops.aten.any.default": None,
"torch.ops.aten.any.dim": None,
"torch.ops.aten.arange.default": None,
"torch.ops.aten.arange.start": None,
"torch.ops.aten.arange.start_step": None,
"torch.ops.aten.argmax.default": None,
"torch.ops.aten.argmin.default": None,
"torch.ops.aten.as_strided.default": None,
"torch.ops.aten.asin.default": None,
"torch.ops.aten.asinh.default": None,
"torch.ops.aten.asinh.default": None,
"torch.ops.aten.atanh.default": None,
"torch.ops.aten.avg_pool2d.default": None,
"torch.ops.aten.avg_pool3d.default": None,
"torch.ops.aten.baddbmm.default": None,
"torch.ops.aten.bitwise_and.Tensor": None,
"torch.ops.aten.bitwise_not.default": None,
"torch.ops.aten.bitwise_or.Tensor": None,
"torch.ops.aten.bitwise_xor.Tensor": None,
"torch.ops.aten.bmm.default": None,
"torch.ops.aten.cat.default": None,
"torch.ops.aten.ceil.default": None,
"torch.ops.aten.clamp.default": None,
"torch.ops.aten.clamp_max.default": None,
"torch.ops.aten.clamp_max.Tensor": None,
"torch.ops.aten.clamp_min.default": None,
"torch.ops.aten.clamp_min.Tensor": None,
"torch.ops.aten.clone.default": None,
"torch.ops.aten.constant_pad_nd.default": None,
"torch.ops.aten.convolution.default": None,
"torch.ops.aten.copy.default": None,
"torch.ops.aten.copy_.default": None,
"torch.ops.aten.cos.default": None,
"torch.ops.aten.cosh.default": None,
"torch.ops.aten.cumsum.default": None,
"torch.ops.aten.detach.default": None,
"torch.ops.aten.detach_.default": None,
"torch.ops.aten.div.Scalar": None,
"torch.ops.aten.div.Tensor": None,
"torch.ops.aten.div.Tensor_mode": None,
"torch.ops.aten.div_.Tensor": None,
"torch.ops.aten.elu.default": None,
"torch.ops.aten.elu_.default": None,
"torch.ops.aten.embedding.default": None,
"torch.ops.aten.empty.memory_format": None,
"torch.ops.aten.erf.default": None,
"torch.ops.aten.eq.Scalar": None,
"torch.ops.aten.eq.Tensor": None,
"torch.ops.aten.erf.default": None,
"torch.ops.aten.exp.default": None,
"torch.ops.aten.expand.default": None,
"torch.ops.aten.fake_quantize_per_channel_affine_cachemask.default": None,
"torch.ops.aten.fill.Scalar": None,
"torch.ops.aten.fill_.Scalar": None,
"torch.ops.aten.fill.Tensor": None,
"torch.ops.aten.fill_.Tensor": None,
"torch.ops.aten.flip.default": None,
"torch.ops.aten.floor.default": None,
"torch.ops.aten.floor.default": None,
"torch.ops.aten.fmod.Scalar": None,
"torch.ops.aten.fmod.Tensor": None,
"torch.ops.aten.full.default": None,
"torch.ops.aten.full.names": None,
"torch.ops.aten.full_like.default": None,
"torch.ops.aten.gather.default": None,
"torch.ops.aten.ge.Scalar": None,
"torch.ops.aten.ge.Tensor": None,
"torch.ops.aten.gelu.default": None,
"torch.ops.aten.glu.default": None,
"torch.ops.aten.grid_sampler_2d.default": None,
"torch.ops.aten.gt.Scalar": None,
"torch.ops.aten.gt.Tensor": None,
"torch.ops.aten.hardsigmoid.default": None,
"torch.ops.aten.hardswish.default": None,
"torch.ops.aten.hardswish_.default": None,
"torch.ops.aten.hardtanh.default": None,
"torch.ops.aten.hardtanh_.default": None,
"torch.ops.aten.index.Tensor": None,
"torch.ops.aten.index_select.default": None,
"torch.ops.aten.isfinite.default": None,
"torch.ops.aten.isinf.default": None,
"torch.ops.aten.isnan.default": None,
"torch.ops.aten.le.Scalar": None,
"torch.ops.aten.le.Tensor": None,
"torch.ops.aten.leaky_relu.default": None,
"torch.ops.aten.leaky_relu_.default": None,
"torch.ops.aten.lift_fresh_copy.default": None,
"torch.ops.aten.linalg_vector_norm.default": None,
"torch.ops.aten.lt.Tensor": None,
"torch.ops.aten.log.default": None,
"torch.ops.aten.log_sigmoid_forward.default": None,
"torch.ops.aten.log10.default": None,
"torch.ops.aten.log1p.default": None,
"torch.ops.aten.log2.default": None,
"torch.ops.aten.logical_not.default": None,
"torch.ops.aten.logsumexp.default": None,
"torch.ops.aten.masked_fill_.Scalar": None,
"torch.ops.aten.lt.Scalar": None,
"torch.ops.aten.lt.Tensor": None,
"torch.ops.aten.masked_fill.Scalar": None,
"torch.ops.aten.masked_fill.Tensor": None,
"torch.ops.aten.masked_fill_.Scalar": None,
"torch.ops.aten.masked_fill_.Tensor": None,
"torch.ops.aten.max.default": None,
"torch.ops.aten.max.dim": None,
"torch.ops.aten.max_pool2d_with_indices.default": None,
"torch.ops.aten.max_pool3d_with_indices.default": None,
"torch.ops.aten.maximum.default": None,
"torch.ops.aten.mean.default": None,
"torch.ops.aten.mean.dim": None,
"torch.ops.aten.min.default": None,
"torch.ops.aten.min.dim": None,
"torch.ops.aten.minimum.default": None,
"torch.ops.aten.mm.default": None,
"torch.ops.aten.mul.Scalar": None,
"torch.ops.aten.mul.Tensor": None,
"torch.ops.aten.native_batch_norm.default": None,
"torch.ops.aten._native_batch_norm_legit.default": None,
"torch.ops.aten._native_batch_norm_legit_no_training.default": None,
"torch.ops.aten.native_dropout.default": None,
"torch.ops.aten.native_group_norm.default": None,
"torch.ops.aten.native_layer_norm.default": None,
"torch.ops.aten.new_full.default": None,
"torch.ops.aten.ne.Scalar": None,
"torch.ops.aten.ne.Tensor": None,
"torch.ops.aten.neg.default": None,
"torch.ops.aten.new_full.default": None,
"torch.ops.aten.new_ones.default": None,
"torch.ops.aten.new_zeros.default": None,
"torch.ops.aten.ones.default": None,
"torch.ops.aten.permute.default": None,
"torch.ops.aten.pow.Scalar": None,
"torch.ops.aten.pow.Tensor_Scalar": None,
"torch.ops.aten.pow.Tensor_Tensor": None,
"torch.ops.aten.rand.default": None,
"torch.ops.aten.reciprocal.default": None,
"torch.ops.aten.relu.default": None,
"torch.ops.aten.relu_.default": None,
"torch.ops.aten.repeat.default": None,
"torch.ops.aten.roll.default": None,
"torch.ops.aten.rsqrt.default": None,
"torch.ops.aten.rsub.Scalar": None,
"torch.ops.aten._scaled_dot_product_flash_attention.default": None,
"torch.ops.aten.rsub.Tensor": None,
"torch.ops.aten.scalar_tensor.default": None,
"torch.ops.aten.scatter.src": None,
"torch.ops.aten.scatter.value": None,
"torch.ops.aten.select.int": None,
"torch.ops.aten.select_scatter.default": None,
"torch.ops.aten.sigmoid.default": None,
"torch.ops.aten.sign.default": None,
"torch.ops.aten.silu.default": None,
"torch.ops.aten.silu_.default": None,
"torch.ops.aten.sin.default": None,
"torch.ops.aten.sinh.default": None,
"torch.ops.aten.slice.Tensor": None,
"torch.ops.aten.slice_scatter.default": None,
"torch.ops.aten.sort.default": None,
"torch.ops.aten.split.Tensor": None,
"torch.ops.aten.split_with_sizes.default": None,
"torch.ops.aten.sqrt.default": None,
"torch.ops.aten.squeeze.dim": None,
"torch.ops.aten.squeeze.dims": None,
"torch.ops.aten.stack.default": None,
"torch.ops.aten.sub.default": None,
"torch.ops.aten.sub.Tensor": None,
"torch.ops.aten.sum.default": None,
"torch.ops.aten.sum.dim_IntList": None,
"torch.ops.aten.t.default": None,
"torch.ops.aten.tan.default": None,
"torch.ops.aten.tanh.default": None,
"torch.ops.aten.topk.default": None,
"torch.ops.aten.transpose.int": None,
"torch.ops.aten.tril.default": None,
"torch.ops.aten.tril_.default": None,
"torch.ops.aten.unbind.int": None,
"torch.ops.aten.unfold.default": None,
"torch.ops.aten.unsqueeze.default": None,
"torch.ops.aten.upsample_nearest2d.default": None,
"torch.ops.aten.var.correction": None,
"torch.ops.aten.var_mean.correction": None,
"torch.ops.aten.view.default": None,
"torch.ops.aten.where.self": None,
"torch.ops.aten.zeros_like.default": None,
"torch.ops.torchvision.deform_conv2d.default": None,
"torch.ops.torchvision.roi_align.default": None,
}

for op in _get_disabled_ops(options):
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/input_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ InputModel::InputModel(const std::shared_ptr<TorchDecoder>& model_decoder) : m_m
const auto& outputs = m_model_decoder->outputs();
for (size_t i = 0; i < outputs.size(); ++i) {
auto out_place = std::make_shared<pytorch::Place>(*this, outputs[i]);
m_name_to_place.emplace(std::to_string(inputs[i]), std::dynamic_pointer_cast<frontend::Place>(out_place));
m_name_to_place.emplace(std::to_string(outputs[i]), std::dynamic_pointer_cast<frontend::Place>(out_place));
for (const auto& name : out_place->get_names()) {
m_name_to_place.emplace(name, std::dynamic_pointer_cast<frontend::Place>(out_place));
}
Expand Down
38 changes: 38 additions & 0 deletions src/frontends/pytorch/src/op/any.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/not_equal.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reduce_logical_or.hpp"
#include "openvino/op/reshape.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_any_fx(const NodeContext& context) {
num_inputs_check(context, 1, 3);
auto x = context.get_input(0);

Output<Node> dims;
if (!context.input_is_none(1)) {
dims = context.get_input(1);
} else {
dims = get_axes_range(context, 0);
}
bool keep_dims = false;
if (!context.input_is_none(2))
keep_dims = context.const_input<bool>(2);
auto any = context.mark_node(std::make_shared<ov::op::v1::ReduceLogicalOr>(x, dims, keep_dims));
return {any};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
6 changes: 4 additions & 2 deletions src/frontends/pytorch/src/op/argmax_argmin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ OutputVector create_argmax_argmin_op(const NodeContext& context, TopKMode mode)
}
if (!context.input_is_none(1)) {
auto axis = context.const_input<int64_t>(1);
auto topk = context.mark_node(std::make_shared<v3::TopK>(input, k, axis, mode, TopKSortType::NONE));
auto topk = context.mark_node(
std::make_shared<v11::TopK>(input, k, axis, mode, TopKSortType::SORT_VALUES, element::i32, true));
indices = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
if (!keep_dims) {
auto axis_to_remove = context.mark_node(v0::Constant::create(element::i32, Shape{}, {axis}));
Expand All @@ -41,7 +42,8 @@ OutputVector create_argmax_argmin_op(const NodeContext& context, TopKMode mode)
int64_t axis = 0;
auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto flatten_input = context.mark_node(std::make_shared<v1::Reshape>(input, minus_one, false));
auto topk = context.mark_node(std::make_shared<v3::TopK>(flatten_input, k, axis, mode, TopKSortType::NONE));
auto topk = context.mark_node(
std::make_shared<v11::TopK>(flatten_input, k, axis, mode, TopKSortType::SORT_VALUES, element::i32, true));
indices = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
if (keep_dims) {
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
Expand Down
16 changes: 10 additions & 6 deletions src/frontends/pytorch/src/op/cat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,24 @@ OutputVector translate_quantized_cat(const NodeContext& context) {
};

OutputVector translate_stack_fx(const NodeContext& context) {
num_inputs_check(context, 2, context.get_input_size());
num_inputs_check(context, 1, context.get_input_size());
auto dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
std::deque<Output<Node>> list_elems;
auto num_elements = context.get_input_size();
if (num_elements > 2)
num_elements = num_elements - 1;
for (size_t i = 0; i < num_elements; i++) {
for (size_t i = 0; i < num_elements - 1; i++) {
auto stack_input =
context.mark_node(std::make_shared<v0::Unsqueeze>(context.get_input(static_cast<int>(i)), dim));
list_elems.push_back(stack_input);
}
int64_t axis = 0;
if (context.get_input_size() > 2)
axis = context.const_input<int64_t>(context.get_input_size() - 1);
if (!context.get_input_type(num_elements - 1).is<type::List>()) {
// axis can be not present and that means that last input will have List type
axis = context.const_input<int64_t>(num_elements - 1);
} else {
auto stack_input = context.mark_node(
std::make_shared<v0::Unsqueeze>(context.get_input(static_cast<int>(num_elements - 1)), dim));
list_elems.push_back(stack_input);
}
return translate_cat_common(context, list_elems, axis, true);
}

Expand Down
11 changes: 11 additions & 0 deletions src/frontends/pytorch/src/op/div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@ OutputVector translate_div_fx(const NodeContext& context) {
return translate_div_common(context, x, y, rounding_mode, false);
};

OutputVector translate_div_fx_(const NodeContext& context) {
suryasidd marked this conversation as resolved.
Show resolved Hide resolved
num_inputs_check(context, 2, 2);
auto x = context.get_input(0);
auto y = context.get_input(1);
std::string rounding_mode = "";
if (context.has_attribute("rounding_mode")) {
rounding_mode = context.get_attribute<std::string>("rounding_mode");
}
return translate_div_common(context, x, y, rounding_mode, true);
};

} // namespace op
} // namespace pytorch
} // namespace frontend
Expand Down
20 changes: 16 additions & 4 deletions src/frontends/pytorch/src/op/embedding_bag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_embedding_bag(const NodeContext& context) {
OutputVector translate_embedding_bag_common(const NodeContext& context) {
// aten::embedding_bag(weight, input, offsets=None, scale_grad_by_freq=False, mode_enum=1, sparse=False,
// per_sample_weights=None, include_last_offset=False, padding_idx=None)
num_inputs_check(context, 9, 9);
// we have only EmbeddingBagSum case support, check it before translation
auto mode = context.const_input<int64_t>(4);
PYTORCH_OP_CONVERSION_CHECK(mode == 0, "Only sum mode supported for aten::embedding_bag translation");
Expand All @@ -43,7 +42,9 @@ OutputVector translate_embedding_bag(const NodeContext& context) {
// with offsets case
auto offsets = context.get_input(2);
offsets = context.mark_node(std::make_shared<ov::op::v0::Convert>(offsets, element::i32));
auto include_last_offset = context.const_input<bool>(7);
bool include_last_offset = false;
if (!context.input_is_none(7))
include_last_offset = context.const_input<bool>(7);
PYTORCH_OP_CONVERSION_CHECK(!include_last_offset, "Inclusion last offset is not supported");
// no per_sample_wights
if (context.input_is_none(6)) {
Expand All @@ -63,7 +64,18 @@ OutputVector translate_embedding_bag(const NodeContext& context) {
return {result, zero, zero, zero};
};

OutputVector translate_embedding_bag(const NodeContext& context) {
num_inputs_check(context, 9, 9);
return translate_embedding_bag_common(context);
}

OutputVector translate_embedding_bag_fx(const NodeContext& context) {
num_inputs_check(context, 7, 9);
ov::OutputVector output = translate_embedding_bag_common(context);
return {context.mark_node(make_list_construct(output))};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
} // namespace ov
Loading
Loading