From 87bed732cb2f70965b29b93302ab0b108480e00c Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Tue, 27 Feb 2024 23:40:42 -0800 Subject: [PATCH 01/25] New ops added to torch.compile support list --- .../pytorch/torchdynamo/op_support.py | 104 ++++++++++++++++-- 1 file changed, 95 insertions(+), 9 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py index 0a7bebd8763215..9a233d86017ec3 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py @@ -29,100 +29,186 @@ class OperatorSupport(OperatorSupport): def __init__(self, options): support_dict = { "_operator.getitem": None, + "torch.ops.aten._adaptive_avg_pool1d.default": None, "torch.ops.aten._adaptive_avg_pool2d.default": None, + "torch.ops.aten._adaptive_avg_pool3d.default": None, + "torch.ops.aten._convolution.default": None, + "torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default": None, + "torch.ops.aten._local_scalar_dense.default": None, "torch.ops.aten._log_softmax.default": None, + "torch.ops.aten._native_batch_norm_legit.default": None, + "torch.ops.aten._native_batch_norm_legit.no_stats": None, + "torch.ops.aten._native_batch_norm_legit_functional.default": None, + "torch.ops.aten._native_batch_norm_legit_no_training.default": None, + "torch.ops.aten._scaled_dot_product_flash_attention.default": None, + "torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default": None, "torch.ops.aten._softmax.default": None, "torch.ops.aten._to_copy.default": None, "torch.ops.aten._unsafe_view.default": None, - "torch.ops.aten._unsafe_view.default": None, + "torch.ops.aten.abs.default": None, + "torch.ops.aten.acos.default": None, + "torch.ops.aten.acosh.default": None, + "torch.ops.aten.adaptive_max_pool1d.default": None, + "torch.ops.aten.adaptive_max_pool2d.default": None, + "torch.ops.aten.adaptive_max_pool3d.default": None, "torch.ops.aten.add.Scalar": None, "torch.ops.aten.add.Tensor": None, "torch.ops.aten.add_.Tensor": None, + "torch.ops.aten.addcmul.default": None, "torch.ops.aten.addmm.default": None, + "torch.ops.aten.alias.default": None, "torch.ops.aten.amax.default": None, - "torch.ops.aten.arange.start": None, + "torch.ops.aten.amin.default": None, "torch.ops.aten.arange.default": None, + "torch.ops.aten.arange.start": None, + "torch.ops.aten.arange.start_step": None, "torch.ops.aten.argmax.default": None, + "torch.ops.aten.argmin.default": None, + "torch.ops.aten.as_strided.default": None, + "torch.ops.aten.asin.default": None, + "torch.ops.aten.asinh.default": None, + "torch.ops.aten.asinh.default": None, + "torch.ops.aten.atanh.default": None, "torch.ops.aten.avg_pool2d.default": None, + "torch.ops.aten.avg_pool3d.default": None, "torch.ops.aten.baddbmm.default": None, "torch.ops.aten.bitwise_and.Tensor": None, + "torch.ops.aten.bitwise_not.default": None, + "torch.ops.aten.bitwise_or.Tensor": None, + "torch.ops.aten.bitwise_xor.Tensor": None, "torch.ops.aten.bmm.default": None, "torch.ops.aten.cat.default": None, + "torch.ops.aten.ceil.default": None, + "torch.ops.aten.clamp.default": None, + "torch.ops.aten.clamp_max.default": None, + "torch.ops.aten.clamp_max.Tensor": None, "torch.ops.aten.clamp_min.default": None, + "torch.ops.aten.clamp_min.Tensor": None, "torch.ops.aten.clone.default": None, + "torch.ops.aten.constant_pad_nd.default": None, 
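Each support_dict key in this patch names the string form of an FX node target (torch.ops.aten.<op>.<overload>) that the OpenVINO TorchDynamo backend will claim for its subgraphs; anything not listed stays on the eager fallback. A minimal usage sketch, assuming the openvino.torch backend registration described in OpenVINO's torch.compile documentation (the model and shapes are illustrative, not part of the diff):

import torch
import openvino.torch  # assumed: registers the "openvino" backend for torch.compile

class SmallModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        # lowers to aten.convolution.default and aten.relu.default,
        # both present in the support list above
        return torch.relu(self.conv(x))

model = SmallModel().eval()
compiled = torch.compile(model, backend="openvino")
print(compiled(torch.randn(1, 3, 32, 32)).shape)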
"torch.ops.aten.convolution.default": None, + "torch.ops.aten.copy.default": None, "torch.ops.aten.copy_.default": None, "torch.ops.aten.cos.default": None, + "torch.ops.aten.cosh.default": None, "torch.ops.aten.cumsum.default": None, "torch.ops.aten.detach.default": None, "torch.ops.aten.div.Scalar": None, "torch.ops.aten.div.Tensor": None, + "torch.ops.aten.div.Tensor_mode": None, "torch.ops.aten.embedding.default": None, "torch.ops.aten.empty.memory_format": None, - "torch.ops.aten.erf.default": None, "torch.ops.aten.eq.Scalar": None, "torch.ops.aten.eq.Tensor": None, + "torch.ops.aten.erf.default": None, "torch.ops.aten.exp.default": None, "torch.ops.aten.expand.default": None, + "torch.ops.aten.fake_quantize_per_channel_affine_cachemask.default": None, "torch.ops.aten.fill.Scalar": None, + "torch.ops.aten.fill.Tensor": None, + "torch.ops.aten.flip.default": None, + "torch.ops.aten.floor.default": None, + "torch.ops.aten.floor.default": None, + "torch.ops.aten.fmod.Scalar": None, + "torch.ops.aten.fmod.Tensor": None, "torch.ops.aten.full.default": None, + "torch.ops.aten.full.names": None, + "torch.ops.aten.full_like.default": None, "torch.ops.aten.gather.default": None, + "torch.ops.aten.ge.Scalar": None, + "torch.ops.aten.ge.Tensor": None, "torch.ops.aten.gelu.default": None, + "torch.ops.aten.glu.default": None, "torch.ops.aten.gt.Scalar": None, + "torch.ops.aten.gt.Tensor": None, "torch.ops.aten.hardsigmoid.default": None, + "torch.ops.aten.hardswish.default": None, "torch.ops.aten.hardswish_.default": None, + "torch.ops.aten.hardtanh.default": None, "torch.ops.aten.hardtanh_.default": None, "torch.ops.aten.index.Tensor": None, + "torch.ops.aten.index_select.default": None, + "torch.ops.aten.le.Scalar": None, + "torch.ops.aten.le.Tensor": None, + "torch.ops.aten.leaky_relu.default": None, "torch.ops.aten.leaky_relu_.default": None, "torch.ops.aten.lift_fresh_copy.default": None, "torch.ops.aten.linalg_vector_norm.default": None, - "torch.ops.aten.lt.Tensor": None, "torch.ops.aten.log.default": None, "torch.ops.aten.log_sigmoid_forward.default": None, + "torch.ops.aten.log10.default": None, + "torch.ops.aten.log1p.default": None, + "torch.ops.aten.log2.default": None, "torch.ops.aten.logsumexp.default": None, - "torch.ops.aten.masked_fill_.Scalar": None, + "torch.ops.aten.lt.Scalar": None, + "torch.ops.aten.lt.Tensor": None, "torch.ops.aten.masked_fill.Tensor": None, + "torch.ops.aten.masked_fill_.Scalar": None, + "torch.ops.aten.max.default": None, "torch.ops.aten.max.dim": None, "torch.ops.aten.max_pool2d_with_indices.default": None, + "torch.ops.aten.max_pool3d_with_indices.default": None, + "torch.ops.aten.maximum.default": None, + "torch.ops.aten.mean.default": None, "torch.ops.aten.mean.dim": None, + "torch.ops.aten.min.default": None, + "torch.ops.aten.min.dim": None, + "torch.ops.aten.minimum.default": None, "torch.ops.aten.mm.default": None, "torch.ops.aten.mul.Scalar": None, "torch.ops.aten.mul.Tensor": None, "torch.ops.aten.native_batch_norm.default": None, - "torch.ops.aten._native_batch_norm_legit.default": None, - "torch.ops.aten._native_batch_norm_legit_no_training.default": None, + "torch.ops.aten.native_dropout.default": None, "torch.ops.aten.native_group_norm.default": None, "torch.ops.aten.native_layer_norm.default": None, - "torch.ops.aten.new_full.default": None, + "torch.ops.aten.ne.Scalar": None, + "torch.ops.aten.ne.Tensor": None, "torch.ops.aten.neg.default": None, + "torch.ops.aten.new_full.default": None, "torch.ops.aten.new_ones.default": None, 
"torch.ops.aten.permute.default": None, + "torch.ops.aten.pow.Scalar": None, "torch.ops.aten.pow.Tensor_Scalar": None, + "torch.ops.aten.pow.Tensor_Tensor": None, + "torch.ops.aten.reciprocal.default": None, "torch.ops.aten.relu.default": None, "torch.ops.aten.relu_.default": None, + "torch.ops.aten.repeat.default": None, + "torch.ops.aten.roll.default": None, "torch.ops.aten.rsqrt.default": None, "torch.ops.aten.rsub.Scalar": None, - "torch.ops.aten._scaled_dot_product_flash_attention.default": None, + "torch.ops.aten.rsub.Tensor": None, "torch.ops.aten.scalar_tensor.default": None, + "torch.ops.aten.scatter.value": None, "torch.ops.aten.select.int": None, + "torch.ops.aten.select_scatter.default": None, "torch.ops.aten.sigmoid.default": None, + "torch.ops.aten.sign.default": None, "torch.ops.aten.silu.default": None, "torch.ops.aten.silu_.default": None, "torch.ops.aten.sin.default": None, + "torch.ops.aten.sinh.default": None, "torch.ops.aten.slice.Tensor": None, + "torch.ops.aten.slice_scatter.default": None, "torch.ops.aten.split.Tensor": None, + "torch.ops.aten.split_with_sizes.default": None, + "torch.ops.aten.sqrt.default": None, "torch.ops.aten.squeeze.dim": None, "torch.ops.aten.squeeze.dims": None, "torch.ops.aten.stack.default": None, "torch.ops.aten.sub.default": None, "torch.ops.aten.sub.Tensor": None, + "torch.ops.aten.sum.deafult": None, "torch.ops.aten.sum.dim_IntList": None, "torch.ops.aten.t.default": None, + "torch.ops.aten.tan.default": None, "torch.ops.aten.tanh.default": None, "torch.ops.aten.transpose.int": None, "torch.ops.aten.unbind.int": None, + "torch.ops.aten.unfold.default": None, "torch.ops.aten.unsqueeze.default": None, "torch.ops.aten.upsample_nearest2d.default": None, + "torch.ops.aten.var.correction": None, "torch.ops.aten.var_mean.correction": None, "torch.ops.aten.view.default": None, "torch.ops.aten.where.self": None, From b5cb48e5f960e546ddfbebfa1fac170f772d96ec Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Wed, 28 Feb 2024 10:46:46 -0800 Subject: [PATCH 02/25] Additional ops for NetVLad and ALIKE --- .../src/openvino/frontend/pytorch/torchdynamo/op_support.py | 5 +++++ src/frontends/pytorch/src/op_table.cpp | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py index 9a233d86017ec3..255b47492d1983 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py @@ -96,6 +96,8 @@ def __init__(self, options): "torch.ops.aten.div.Scalar": None, "torch.ops.aten.div.Tensor": None, "torch.ops.aten.div.Tensor_mode": None, + "torch.ops.aten.elu.default": None, + "torch.ops.aten.elu_.default": None, "torch.ops.aten.embedding.default": None, "torch.ops.aten.empty.memory_format": None, "torch.ops.aten.eq.Scalar": None, @@ -119,6 +121,7 @@ def __init__(self, options): "torch.ops.aten.ge.Tensor": None, "torch.ops.aten.gelu.default": None, "torch.ops.aten.glu.default": None, + "torch.ops.aten.grid_sampler_2d.default": None, "torch.ops.aten.gt.Scalar": None, "torch.ops.aten.gt.Tensor": None, "torch.ops.aten.hardsigmoid.default": None, @@ -207,12 +210,14 @@ def __init__(self, options): "torch.ops.aten.unbind.int": None, "torch.ops.aten.unfold.default": None, "torch.ops.aten.unsqueeze.default": None, + "torch.ops.aten.upsample_bilinear2d.default": None, 
"torch.ops.aten.upsample_nearest2d.default": None, "torch.ops.aten.var.correction": None, "torch.ops.aten.var_mean.correction": None, "torch.ops.aten.view.default": None, "torch.ops.aten.where.self": None, "torch.ops.aten.zeros_like.default": None, + "torch.ops.torchvision.deform_conv2d.default": None, } for op in _get_disabled_ops(options): diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 26280066c90777..25284d9fe236b6 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -750,6 +750,8 @@ const std::map get_supported_ops_fx() { {"aten.div.Scalar", op::translate_div_fx}, {"aten.div.Tensor", op::translate_div_fx}, {"aten.div.Tensor_mode", op::translate_div_fx}, + {"aten.elu.default", op::translate_elu}, + {"aten.elu_.default", op::inplace_op}, {"aten.embedding.default", op::translate_embedding}, {"aten.empty.memory_format", op::translate_empty}, {"aten.eq.Scalar", op::translate_1to1_match_2_inputs_align_types}, @@ -773,6 +775,7 @@ const std::map get_supported_ops_fx() { {"aten.ge.Tensor", op::translate_1to1_match_2_inputs_align_types}, {"aten.gelu.default", op::translate_gelu_fx}, {"aten.glu.default", op::translate_glu}, + {"aten.grid_sampler_2d.default", op::translate_grid_sampler}, {"aten.gt.Scalar", op::translate_1to1_match_2_inputs_align_types}, {"aten.gt.Tensor", op::translate_1to1_match_2_inputs_align_types}, {"aten.hardsigmoid.default", op::translate_1to1_match_1_inputs}, @@ -861,6 +864,7 @@ const std::map get_supported_ops_fx() { {"aten.unbind.int", op::translate_unbind_int_fx}, {"aten.unfold.default", op::translate_unfold}, {"aten.unsqueeze.default", op::translate_1to1_match_2_inputs}, + {"aten.upsample_bilinear2d.default", op::translate_upsample_bilinear2d}, {"aten.upsample_nearest2d.default", op::translate_upsample_nearest2d}, {"aten.var.correction", op::translate_var_fx}, {"aten.var_mean.correction", op::translate_var_mean_fx}, @@ -879,7 +883,7 @@ const std::map get_supported_ops_fx() { {"prim::PythonOp", op::translate_pythonop}, {"prim::requires_grad", op::return_false_scalar}, {"prim::type", op::skip_node}, // Used with prim::device, pass PtFrameworkNode. 
- {"torchvision::deform_conv2d", op::translate_deform_conv}, + {"torchvision.deform_conv2d.default", op::translate_deform_conv}, {"torchvision::nms", op::translate_nms}, {"torchvision::roi_align", op::translate_roi_align}, }; From 24cae2bdcfb144646a8fbf4cdf325da661987546 Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Wed, 28 Feb 2024 14:28:12 -0800 Subject: [PATCH 03/25] Additional op support for ChatGLM2 --- .../pytorch/torchdynamo/op_support.py | 15 ++++++++ src/frontends/pytorch/src/op/any.cpp | 32 +++++++++++++++++ src/frontends/pytorch/src/op/bitwise.cpp | 20 +++++++++++ src/frontends/pytorch/src/op/cat.cpp | 12 +++++-- src/frontends/pytorch/src/op/isinf.cpp | 25 ++++++++++++++ src/frontends/pytorch/src/op/isnan.cpp | 25 ++++++++++++++ src/frontends/pytorch/src/op/sort.cpp | 30 ++++++++++++++++ src/frontends/pytorch/src/op/split.cpp | 6 ++-- src/frontends/pytorch/src/op/topk.cpp | 34 +++++++++++++++++++ src/frontends/pytorch/src/op/var_mean.cpp | 7 +++- src/frontends/pytorch/src/op_table.cpp | 24 +++++++++++-- 11 files changed, 221 insertions(+), 9 deletions(-) create mode 100644 src/frontends/pytorch/src/op/any.cpp create mode 100644 src/frontends/pytorch/src/op/isinf.cpp create mode 100644 src/frontends/pytorch/src/op/isnan.cpp diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py index 255b47492d1983..1790762c9a242d 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py @@ -57,8 +57,10 @@ def __init__(self, options): "torch.ops.aten.addcmul.default": None, "torch.ops.aten.addmm.default": None, "torch.ops.aten.alias.default": None, + "torch.ops.aten.all.default": None, "torch.ops.aten.amax.default": None, "torch.ops.aten.amin.default": None, + "torch.ops.aten.any.default": None, "torch.ops.aten.arange.default": None, "torch.ops.aten.arange.start": None, "torch.ops.aten.arange.start_step": None, @@ -75,6 +77,7 @@ def __init__(self, options): "torch.ops.aten.bitwise_and.Tensor": None, "torch.ops.aten.bitwise_not.default": None, "torch.ops.aten.bitwise_or.Tensor": None, + "torch.ops.aten.bitwise_right_shift.Tensor": None, "torch.ops.aten.bitwise_xor.Tensor": None, "torch.ops.aten.bmm.default": None, "torch.ops.aten.cat.default": None, @@ -93,6 +96,7 @@ def __init__(self, options): "torch.ops.aten.cosh.default": None, "torch.ops.aten.cumsum.default": None, "torch.ops.aten.detach.default": None, + "torch.ops.aten.detach_.default": None, "torch.ops.aten.div.Scalar": None, "torch.ops.aten.div.Tensor": None, "torch.ops.aten.div.Tensor_mode": None, @@ -108,6 +112,7 @@ def __init__(self, options): "torch.ops.aten.fake_quantize_per_channel_affine_cachemask.default": None, "torch.ops.aten.fill.Scalar": None, "torch.ops.aten.fill.Tensor": None, + "torch.ops.aten.fill_.Tensor": None, "torch.ops.aten.flip.default": None, "torch.ops.aten.floor.default": None, "torch.ops.aten.floor.default": None, @@ -131,6 +136,8 @@ def __init__(self, options): "torch.ops.aten.hardtanh_.default": None, "torch.ops.aten.index.Tensor": None, "torch.ops.aten.index_select.default": None, + "torch.ops.aten.isinf.default": None, + "torch.ops.aten.isnan.default": None, "torch.ops.aten.le.Scalar": None, "torch.ops.aten.le.Tensor": None, "torch.ops.aten.leaky_relu.default": None, @@ -142,10 +149,12 @@ def __init__(self, options): "torch.ops.aten.log10.default": None, 
"torch.ops.aten.log1p.default": None, "torch.ops.aten.log2.default": None, + "torch.ops.aten.logical_not.default": None, "torch.ops.aten.logsumexp.default": None, "torch.ops.aten.lt.Scalar": None, "torch.ops.aten.lt.Tensor": None, "torch.ops.aten.masked_fill.Tensor": None, + "torch.ops.aten.masked_fill.Scalar": None, "torch.ops.aten.masked_fill_.Scalar": None, "torch.ops.aten.max.default": None, "torch.ops.aten.max.dim": None, @@ -169,6 +178,7 @@ def __init__(self, options): "torch.ops.aten.neg.default": None, "torch.ops.aten.new_full.default": None, "torch.ops.aten.new_ones.default": None, + "torch.ops.aten.ones.default": None, "torch.ops.aten.permute.default": None, "torch.ops.aten.pow.Scalar": None, "torch.ops.aten.pow.Tensor_Scalar": None, @@ -182,6 +192,7 @@ def __init__(self, options): "torch.ops.aten.rsub.Scalar": None, "torch.ops.aten.rsub.Tensor": None, "torch.ops.aten.scalar_tensor.default": None, + "torch.ops.aten.scatter.src": None, "torch.ops.aten.scatter.value": None, "torch.ops.aten.select.int": None, "torch.ops.aten.select_scatter.default": None, @@ -193,6 +204,7 @@ def __init__(self, options): "torch.ops.aten.sinh.default": None, "torch.ops.aten.slice.Tensor": None, "torch.ops.aten.slice_scatter.default": None, + "torch.ops.aten.sort.default": None, "torch.ops.aten.split.Tensor": None, "torch.ops.aten.split_with_sizes.default": None, "torch.ops.aten.sqrt.default": None, @@ -206,7 +218,10 @@ def __init__(self, options): "torch.ops.aten.t.default": None, "torch.ops.aten.tan.default": None, "torch.ops.aten.tanh.default": None, + "torch.ops.aten.topk.default": None, "torch.ops.aten.transpose.int": None, + "torch.ops.aten.tril.default": None, + "torch.ops.aten.tril_.default": None, "torch.ops.aten.unbind.int": None, "torch.ops.aten.unfold.default": None, "torch.ops.aten.unsqueeze.default": None, diff --git a/src/frontends/pytorch/src/op/any.cpp b/src/frontends/pytorch/src/op/any.cpp new file mode 100644 index 00000000000000..32f4664e8b3195 --- /dev/null +++ b/src/frontends/pytorch/src/op/any.cpp @@ -0,0 +1,32 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/reduce_logical_or.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/not_equal.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_any_fx(const NodeContext& context) { + num_inputs_check(context, 1, 1); + auto x = context.get_input(0); + auto num_inputs = context.get_input_size(); + bool keep_dims = false; + auto const_minus_one = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {-1})); + auto flatten_source = context.mark_node(std::make_shared(x, const_minus_one, false)); + auto const_zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {0})); + auto any = context.mark_node(std::make_shared(flatten_source, const_zero, keep_dims)); + return {any}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op/bitwise.cpp b/src/frontends/pytorch/src/op/bitwise.cpp index 84465502969d81..c9a286a2a97ac6 100644 --- a/src/frontends/pytorch/src/op/bitwise.cpp +++ b/src/frontends/pytorch/src/op/bitwise.cpp @@ -7,6 +7,8 @@ #include "openvino/op/bitwise_not.hpp" #include "openvino/op/bitwise_or.hpp" #include "openvino/op/bitwise_xor.hpp" +#include 
"openvino/op/convert.hpp" +#include "openvino/op/logical_not.hpp" #include "utils.hpp" namespace ov { @@ -60,6 +62,24 @@ OutputVector translate_bitwise_xor(const NodeContext& context) { return {xor_x}; }; +OutputVector translate_bitwise_not_fx(const NodeContext& context) { + num_inputs_check(context, 1, 2); + auto x = context.get_input(0); + if (x.get_element_type() != element::boolean) { + auto x_bool = context.mark_node(std::make_shared(x, element::boolean)); + auto not_x = context.mark_node(std::make_shared(x_bool)); + if (!context.input_is_none(1)) { + context.mutate_input(1, not_x); + } + return {not_x}; + } + auto not_x = context.mark_node(std::make_shared(x)); + if (!context.input_is_none(1)) { + context.mutate_input(1, not_x); + } + return {not_x}; +}; + } // namespace op } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/op/cat.cpp b/src/frontends/pytorch/src/op/cat.cpp index 3baec6fea4db05..33d2f5f18a20fd 100644 --- a/src/frontends/pytorch/src/op/cat.cpp +++ b/src/frontends/pytorch/src/op/cat.cpp @@ -106,9 +106,7 @@ OutputVector translate_stack_fx(const NodeContext& context) { auto dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); std::deque> list_elems; auto num_elements = context.get_input_size(); - if (num_elements > 2) - num_elements = num_elements - 1; - for (size_t i = 0; i < num_elements; i++) { + for (size_t i = 0; i < num_elements - 1; i++) { auto stack_input = context.mark_node(std::make_shared(context.get_input(static_cast(i)), dim)); list_elems.push_back(stack_input); @@ -116,6 +114,14 @@ OutputVector translate_stack_fx(const NodeContext& context) { int64_t axis = 0; if (context.get_input_size() > 2) axis = context.const_input(context.get_input_size() - 1); + if (!context.get_input_type(context.get_input_size() - 1).is()) { + // axis can be not present and that means that last input will have List type + axis = context.const_input(context.get_input_size() - 1); + } else { + auto stack_input = + context.mark_node(std::make_shared(context.get_input(static_cast(context.get_input_size() - 1)), dim)); + list_elems.push_back(stack_input); + } return translate_cat_common(context, list_elems, axis, true); } diff --git a/src/frontends/pytorch/src/op/isinf.cpp b/src/frontends/pytorch/src/op/isinf.cpp new file mode 100644 index 00000000000000..1a100391931f61 --- /dev/null +++ b/src/frontends/pytorch/src/op/isinf.cpp @@ -0,0 +1,25 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/add.hpp" + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/is_inf.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_isinf_fx(const NodeContext& context) { + num_inputs_check(context, 1, 1); + auto input = context.get_input(0); + return {context.mark_node(std::make_shared(input))}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op/isnan.cpp b/src/frontends/pytorch/src/op/isnan.cpp new file mode 100644 index 00000000000000..d995546b22b37e --- /dev/null +++ b/src/frontends/pytorch/src/op/isnan.cpp @@ -0,0 +1,25 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/add.hpp" + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/is_nan.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { 
+namespace pytorch { +namespace op { + +OutputVector translate_isnan_fx(const NodeContext& context) { + num_inputs_check(context, 1, 1); + auto input = context.get_input(0); + return {context.mark_node(std::make_shared(input))}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op/sort.cpp b/src/frontends/pytorch/src/op/sort.cpp index 7e75a98c6cebd8..200844c6c10466 100644 --- a/src/frontends/pytorch/src/op/sort.cpp +++ b/src/frontends/pytorch/src/op/sort.cpp @@ -41,6 +41,36 @@ OutputVector translate_sort(const NodeContext& context) { return topk->outputs(); }; +OutputVector translate_sort_fx(const NodeContext& context) { + // aten.sort.default(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + num_inputs_check(context, 1, 3); + const auto input_tensor = context.get_input(0); + bool descending = false; + int64_t dim = -1; + + if (!context.input_is_none(1)) { + dim = context.const_input(1); + } + if (!context.input_is_none(1)) { + descending = context.const_input(2); + } + + auto mode = descending ? ov::op::TopKMode::MAX : ov::op::TopKMode::MIN; + auto zero_axis = context.mark_node(opset11::Constant::create(element::i32, Shape{1}, {0})); + auto dim_axis = context.mark_node(opset11::Constant::create(element::i64, Shape{1}, {dim})); + auto shape = context.mark_node(std::make_shared(input_tensor)); + auto k_values_node = context.mark_node(std::make_shared(shape, dim_axis, zero_axis)); + auto k_values = context.mark_node(std::make_shared(k_values_node)); + auto topk = context.mark_node(std::make_shared(input_tensor, + k_values, + dim, + mode, + ov::op::TopKSortType::SORT_VALUES, + element::i64)); + auto indices = context.mark_node(std::make_shared(topk->output(1), element::i64)); + return {context.mark_node(make_list_construct(OutputVector({topk->output(0), indices})))}; +}; + OutputVector translate_argsort(const NodeContext& context) { auto sort = translate_sort(context); return {sort[1]}; diff --git a/src/frontends/pytorch/src/op/split.cpp b/src/frontends/pytorch/src/op/split.cpp index e983c3031d0f91..b8345a0b4a9700 100644 --- a/src/frontends/pytorch/src/op/split.cpp +++ b/src/frontends/pytorch/src/op/split.cpp @@ -25,11 +25,11 @@ OutputVector translate_chunk_fx(const NodeContext& context) { std::shared_ptr chunk; auto dim_val = context.const_input(2); - auto shape = context.get_input(0).get_shape(); + auto shape = context.get_input(0).get_partial_shape(); if (dim_val < 0) { - dim_val = static_cast(shape.size()) + dim_val; + dim_val = static_cast(shape.rank().get_length()) + dim_val; } - int num_splits = static_cast(shape[dim_val]) / num_chunks; + int num_splits = static_cast(shape[dim_val].get_length()) / num_chunks; chunk = context.mark_node(std::make_shared(context.get_input(0), dim, num_splits)); diff --git a/src/frontends/pytorch/src/op/topk.cpp b/src/frontends/pytorch/src/op/topk.cpp index 06916c4ea03e2f..345cbabb86299d 100644 --- a/src/frontends/pytorch/src/op/topk.cpp +++ b/src/frontends/pytorch/src/op/topk.cpp @@ -5,6 +5,7 @@ #include "openvino/op/topk.hpp" #include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/util/framework_node.hpp" #include "openvino/op/convert.hpp" #include "utils.hpp" @@ -41,6 +42,39 @@ OutputVector translate_topk(const NodeContext& context) { return {topk->output(0), indices}; }; +OutputVector translate_topk_fx(const NodeContext& context) { + // aten.topk.default(Tensor self, int k, int dim=-1, bool largest=True, bool 
sorted=True) -> Tuple[Tensor, Tensor] + num_inputs_check(context, 2, 5); + const auto input_tensor = context.get_input(0); + auto k = context.get_input(1); + int64_t axis{-1}; + bool largest = true; + bool sorted = true; + auto mode = TopKMode::MIN; + auto sort = TopKSortType::NONE; + + if (!context.input_is_none(2)) { + axis = context.const_input(2); + } + if (!context.input_is_none(3)) { + largest = context.const_input(3); + } + if (!context.input_is_none(4)) { + sorted = context.const_input(4); + } + if (largest) { + mode = TopKMode::MAX; + } + if (sorted) { + sort = TopKSortType::SORT_VALUES; + } + + auto topk = context.mark_node(std::make_shared(input_tensor, k, axis, mode, sort)); + auto indices = context.mark_node(std::make_shared(topk->output(1), element::i64)); + + return {context.mark_node(make_list_construct(OutputVector({topk->output(0), indices})))}; +}; + } // namespace op } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/op/var_mean.cpp b/src/frontends/pytorch/src/op/var_mean.cpp index c4937a1c6d888f..8c7056a329b6ca 100644 --- a/src/frontends/pytorch/src/op/var_mean.cpp +++ b/src/frontends/pytorch/src/op/var_mean.cpp @@ -145,6 +145,11 @@ OutputVector translate_var(const NodeContext& context) { return {res[0]}; } +OutputVector translate_var_correction_fx(const NodeContext& context) { + auto res = translate_var_mean(context); + return {context.mark_node(make_list_construct(res))}; +} + OutputVector translate_std(const NodeContext& context) { auto res = translate_var_mean(context); auto var = res[0]; @@ -160,4 +165,4 @@ OutputVector translate_std_mean(const NodeContext& context) { } // namespace op } // namespace pytorch } // namespace frontend -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 25284d9fe236b6..dfef7ae6b08b49 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -243,7 +243,9 @@ OP_CONVERTER(translate_adaptive_max_pool2d_fx); OP_CONVERTER(translate_adaptive_max_pool3d_fx); OP_CONVERTER(translate_addcmul_fx); OP_CONVERTER(translate_addmm_fx); +OP_CONVERTER(translate_any_fx); OP_CONVERTER(translate_arange_fx); +OP_CONVERTER(translate_bitwise_not_fx); OP_CONVERTER(translate_batch_norm_legit_fx); OP_CONVERTER(translate_batch_norm_legit_no_training_fx); OP_CONVERTER(translate_batch_norm_legit_no_stats_fx); @@ -259,6 +261,8 @@ OP_CONVERTER(translate_full_fx); OP_CONVERTER(translate_gelu_fx); OP_CONVERTER(translate_group_norm_fx); OP_CONVERTER(translate_index_fx); +OP_CONVERTER(translate_isinf_fx); +OP_CONVERTER(translate_isnan_fx); OP_CONVERTER(translate_layer_norm_fx); OP_CONVERTER(translate_leaky_relu_fx); OP_CONVERTER(translate_log_softmax_fx); @@ -273,10 +277,12 @@ OP_CONVERTER(translate_select_scatter_fx); OP_CONVERTER(translate_slice_fx); OP_CONVERTER(translate_slice_scatter_fx); OP_CONVERTER(translate_softmax_fx); +OP_CONVERTER(translate_sort_fx); OP_CONVERTER(translate_split_with_sizes_fx); OP_CONVERTER(translate_stack_fx); OP_CONVERTER(translate_sub_fx); OP_CONVERTER(translate_sum_fx); +OP_CONVERTER(translate_topk_fx); OP_CONVERTER(translate_to_fx); OP_CONVERTER(translate_transpose_fx); OP_CONVERTER(translate_var_fx); @@ -711,8 +717,10 @@ const std::map get_supported_ops_fx() { {"aten.addcmul.default", op::translate_addcmul_fx}, {"aten.addmm.default", op::translate_addmm_fx}, {"aten.alias.default", op::skip_node}, + {"aten.all.default", op::translate_all}, 
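The translate_sort_fx and translate_topk_fx converters defined earlier in this patch both lower to OpenVINO TopK and must return a (values, indices) pair with i64 indices. For reference, the eager semantics they are expected to reproduce (plain PyTorch, not part of the patch):

import torch

x = torch.tensor([[3.0, 1.0, 2.0], [0.5, 4.0, -1.0]])

# aten.sort.default: full sort along dim, descending=False means ascending order
values, indices = torch.sort(x, dim=-1, descending=False)

# aten.topk.default: largest/sorted select the TopK mode and output ordering
top_vals, top_idx = torch.topk(x, k=2, dim=-1, largest=True, sorted=True)

# the returned indices always gather back to the returned values
assert torch.equal(torch.gather(x, -1, indices), values)
assert torch.equal(torch.gather(x, -1, top_idx), top_vals)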
{"aten.amax.default", op::translate_amax}, {"aten.amin.default", op::translate_amin}, + {"aten.any.default", op::translate_any_fx}, {"aten.arange.default", op::translate_arange_fx}, {"aten.arange.start", op::translate_arange_fx}, {"aten.arange.start_step", op::translate_arange_fx}, @@ -727,7 +735,7 @@ const std::map get_supported_ops_fx() { {"aten.avg_pool3d.default", op::translate_avg_poolnd}, {"aten.baddbmm.default", op::translate_addmm_fx}, {"aten.bitwise_and.Tensor", op::translate_bitwise_and}, - {"aten.bitwise_not.default", op::translate_bitwise_not}, + {"aten.bitwise_not.default", op::translate_bitwise_not_fx}, {"aten.bitwise_or.Tensor", op::translate_bitwise_or}, {"aten.bitwise_xor.Tensor", op::translate_bitwise_xor}, {"aten.bmm.default", op::translate_1to1_match_2_inputs_align_types}, @@ -747,6 +755,7 @@ const std::map get_supported_ops_fx() { {"aten.cosh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten.cumsum.default", op::translate_cumsum_fx}, {"aten.detach.default", op::skip_node}, + {"aten.detach_.default", op::skip_node}, {"aten.div.Scalar", op::translate_div_fx}, {"aten.div.Tensor", op::translate_div_fx}, {"aten.div.Tensor_mode", op::translate_div_fx}, @@ -762,6 +771,7 @@ const std::map get_supported_ops_fx() { {"aten.fake_quantize_per_channel_affine_cachemask.default", op::translate_fake_quantize_per_channel_affine_fx}, {"aten.fill.Scalar", op::translate_fill}, {"aten.fill.Tensor", op::translate_fill}, + {"aten.fill_.Tensor", op::inplace_op}, {"aten.flip.default", op::translate_flip}, {"aten.floor.default", op::translate_1to1_match_1_inputs}, {"aten.floor_divide.default", op::translate_floor_divide}, @@ -785,6 +795,8 @@ const std::map get_supported_ops_fx() { {"aten.hardtanh_.default", op::inplace_op}, {"aten.index.Tensor", op::translate_index_fx}, {"aten.index_select.default", op::translate_index_select}, + {"aten.isinf.default", op::translate_isinf_fx}, + {"aten.isnan.default", op::translate_isnan_fx}, {"aten.le.Scalar", op::translate_1to1_match_2_inputs_align_types}, {"aten.le.Tensor", op::translate_1to1_match_2_inputs_align_types}, {"aten.leaky_relu.default", op::translate_leaky_relu_fx}, @@ -799,8 +811,10 @@ const std::map get_supported_ops_fx() { {"aten.logsumexp.default", op::translate_logsumexp}, {"aten.lt.Scalar", op::translate_1to1_match_2_inputs_align_types}, {"aten.lt.Tensor", op::translate_1to1_match_2_inputs_align_types}, - {"aten.masked_fill.Tensor", op::translate_masked_fill}, + {"aten.masked_fill.Scalar", op::translate_masked_fill}, {"aten.masked_fill_.Scalar", op::inplace_op}, + {"aten.masked_fill.Tensor", op::translate_masked_fill}, + {"aten.masked_fill_.Tensor", op::inplace_op}, {"aten.max.default", op::translate_max}, {"aten.max.dim", op::translate_max_dim_fx}, {"aten.max_pool2d_with_indices.default", op::translate_max_poolnd_fx}, @@ -823,6 +837,7 @@ const std::map get_supported_ops_fx() { {"aten.neg.default", op::translate_neg}, {"aten.new_full.default", op::translate_new_full}, {"aten.new_ones.default", op::translate_new_ones}, + {"aten.ones.default", op::translate_ones}, {"aten.permute.default", op::translate_1to1_match_2_inputs}, {"aten.pow.Scalar", op::translate_pow}, {"aten.pow.Tensor_Scalar", op::translate_pow}, @@ -836,6 +851,7 @@ const std::map get_supported_ops_fx() { {"aten.rsub.Scalar", op::translate_rsub_fx}, {"aten.rsub.Tensor", op::translate_rsub_fx}, {"aten.scalar_tensor.default", op::translate_scalar_tensor_fx}, + {"aten.scatter.src", op::translate_scatter}, {"aten.scatter.value", op::translate_scatter}, 
{"aten.select.int", op::translate_select}, {"aten.select_scatter.default", op::translate_select_scatter_fx}, @@ -847,12 +863,14 @@ const std::map get_supported_ops_fx() { {"aten.sinh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten.slice.Tensor", op::translate_slice_fx}, {"aten.slice_scatter.default", op::translate_slice_scatter_fx}, + {"aten.sort.default", op::translate_sort_fx}, {"aten.split.Tensor", op::translate_chunk_fx}, {"aten.split_with_sizes.default", op::translate_split_with_sizes_fx}, {"aten.sqrt.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten.squeeze.dim", op::translate_squeeze}, {"aten.squeeze.dims", op::translate_squeeze}, {"aten.stack.default", op::translate_stack_fx}, + {"aten.stack.default", op::translate_stack_fx}, {"aten.sub.default", op::translate_sub_fx}, {"aten.sub.Tensor", op::translate_sub_fx}, {"aten.sum.default", op::translate_sum_fx}, @@ -860,7 +878,9 @@ const std::map get_supported_ops_fx() { {"aten.t.default", op::translate_t}, {"aten.tan.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten.tanh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, + {"aten.topk.default", op::translate_topk_fx}, {"aten.transpose.int", op::translate_transpose}, + {"aten.tril.default", op::translate_tril}, {"aten.unbind.int", op::translate_unbind_int_fx}, {"aten.unfold.default", op::translate_unfold}, {"aten.unsqueeze.default", op::translate_1to1_match_2_inputs}, From 16bbbd608368deb72ab82ba764abde735781fdd4 Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Wed, 28 Feb 2024 15:30:06 -0800 Subject: [PATCH 04/25] PTFE input_model constructor output access fix --- src/frontends/pytorch/src/input_model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/frontends/pytorch/src/input_model.cpp b/src/frontends/pytorch/src/input_model.cpp index bce1937151d92d..765b4b70f3ce89 100644 --- a/src/frontends/pytorch/src/input_model.cpp +++ b/src/frontends/pytorch/src/input_model.cpp @@ -24,7 +24,7 @@ InputModel::InputModel(const std::shared_ptr& model_decoder) : m_m const auto& outputs = m_model_decoder->outputs(); for (size_t i = 0; i < outputs.size(); ++i) { auto out_place = std::make_shared(*this, outputs[i]); - m_name_to_place.emplace(std::to_string(inputs[i]), std::dynamic_pointer_cast(out_place)); + m_name_to_place.emplace(std::to_string(outputs[i]), std::dynamic_pointer_cast(out_place)); for (const auto& name : out_place->get_names()) { m_name_to_place.emplace(name, std::dynamic_pointer_cast(out_place)); } From 50b9c63f05a4c83a44228a372b8d2a33508f2f39 Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Thu, 29 Feb 2024 14:14:08 -0800 Subject: [PATCH 05/25] Removed bitwise_not fx version --- src/frontends/pytorch/src/op/bitwise.cpp | 18 ------------------ src/frontends/pytorch/src/op_table.cpp | 3 +-- 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/src/frontends/pytorch/src/op/bitwise.cpp b/src/frontends/pytorch/src/op/bitwise.cpp index c9a286a2a97ac6..0c91278e27dadf 100644 --- a/src/frontends/pytorch/src/op/bitwise.cpp +++ b/src/frontends/pytorch/src/op/bitwise.cpp @@ -62,24 +62,6 @@ OutputVector translate_bitwise_xor(const NodeContext& context) { return {xor_x}; }; -OutputVector translate_bitwise_not_fx(const NodeContext& context) { - num_inputs_check(context, 1, 2); - auto x = context.get_input(0); - if (x.get_element_type() != element::boolean) { - auto x_bool = context.mark_node(std::make_shared(x, element::boolean)); - auto not_x = 
context.mark_node(std::make_shared(x_bool)); - if (!context.input_is_none(1)) { - context.mutate_input(1, not_x); - } - return {not_x}; - } - auto not_x = context.mark_node(std::make_shared(x)); - if (!context.input_is_none(1)) { - context.mutate_input(1, not_x); - } - return {not_x}; -}; - } // namespace op } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index dfef7ae6b08b49..d926a1a8e073c4 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -245,7 +245,6 @@ OP_CONVERTER(translate_addcmul_fx); OP_CONVERTER(translate_addmm_fx); OP_CONVERTER(translate_any_fx); OP_CONVERTER(translate_arange_fx); -OP_CONVERTER(translate_bitwise_not_fx); OP_CONVERTER(translate_batch_norm_legit_fx); OP_CONVERTER(translate_batch_norm_legit_no_training_fx); OP_CONVERTER(translate_batch_norm_legit_no_stats_fx); @@ -735,7 +734,7 @@ const std::map get_supported_ops_fx() { {"aten.avg_pool3d.default", op::translate_avg_poolnd}, {"aten.baddbmm.default", op::translate_addmm_fx}, {"aten.bitwise_and.Tensor", op::translate_bitwise_and}, - {"aten.bitwise_not.default", op::translate_bitwise_not_fx}, + {"aten.bitwise_not.default", op::translate_bitwise_not}, {"aten.bitwise_or.Tensor", op::translate_bitwise_or}, {"aten.bitwise_xor.Tensor", op::translate_bitwise_xor}, {"aten.bmm.default", op::translate_1to1_match_2_inputs_align_types}, From 5e907a67d58d8b118f3b74797ec4b04b6dad8a4a Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Thu, 29 Feb 2024 18:44:19 -0800 Subject: [PATCH 06/25] Additional op support for TorchFX --- .../src/openvino/frontend/pytorch/torchdynamo/op_support.py | 3 ++- src/frontends/pytorch/src/op_table.cpp | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py index 1790762c9a242d..0b4e69624c4aaa 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py @@ -111,6 +111,7 @@ def __init__(self, options): "torch.ops.aten.expand.default": None, "torch.ops.aten.fake_quantize_per_channel_affine_cachemask.default": None, "torch.ops.aten.fill.Scalar": None, + "torch.ops.aten.fill_.Scalar": None, "torch.ops.aten.fill.Tensor": None, "torch.ops.aten.fill_.Tensor": None, "torch.ops.aten.flip.default": None, @@ -225,7 +226,6 @@ def __init__(self, options): "torch.ops.aten.unbind.int": None, "torch.ops.aten.unfold.default": None, "torch.ops.aten.unsqueeze.default": None, - "torch.ops.aten.upsample_bilinear2d.default": None, "torch.ops.aten.upsample_nearest2d.default": None, "torch.ops.aten.var.correction": None, "torch.ops.aten.var_mean.correction": None, @@ -233,6 +233,7 @@ def __init__(self, options): "torch.ops.aten.where.self": None, "torch.ops.aten.zeros_like.default": None, "torch.ops.torchvision.deform_conv2d.default": None, + "torch.ops.torchvision.roi_align.default": None, } for op in _get_disabled_ops(options): diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index d926a1a8e073c4..3f91f55ae42272 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -769,6 +769,7 @@ const std::map get_supported_ops_fx() { {"aten.expand.default", op::translate_expand_fx}, 
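torch.ops.torchvision.roi_align.default, newly added to the support list in this patch, corresponds to torchvision.ops.roi_align. A hedged sketch with illustrative shapes (each box row uses the (batch_index, x1, y1, x2, y2) layout in input coordinates):

import torch
from torchvision.ops import roi_align

features = torch.randn(1, 8, 16, 16)
# one region: batch index 0, box from (0, 0) to (8, 8)
boxes = torch.tensor([[0.0, 0.0, 0.0, 8.0, 8.0]])

pooled = roi_align(features, boxes, output_size=(4, 4), spatial_scale=1.0, sampling_ratio=2)
print(pooled.shape)  # torch.Size([1, 8, 4, 4])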
{"aten.fake_quantize_per_channel_affine_cachemask.default", op::translate_fake_quantize_per_channel_affine_fx}, {"aten.fill.Scalar", op::translate_fill}, + {"aten.fill_.Scalar", op::inplace_op}, {"aten.fill.Tensor", op::translate_fill}, {"aten.fill_.Tensor", op::inplace_op}, {"aten.flip.default", op::translate_flip}, @@ -883,7 +884,6 @@ const std::map get_supported_ops_fx() { {"aten.unbind.int", op::translate_unbind_int_fx}, {"aten.unfold.default", op::translate_unfold}, {"aten.unsqueeze.default", op::translate_1to1_match_2_inputs}, - {"aten.upsample_bilinear2d.default", op::translate_upsample_bilinear2d}, {"aten.upsample_nearest2d.default", op::translate_upsample_nearest2d}, {"aten.var.correction", op::translate_var_fx}, {"aten.var_mean.correction", op::translate_var_mean_fx}, From 3f9a245a152cf7f3dc83f07a2e5fd4f2e86b801b Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Thu, 29 Feb 2024 23:34:06 -0800 Subject: [PATCH 07/25] Stack translation fix for TorchFX --- src/frontends/pytorch/src/op/cat.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/frontends/pytorch/src/op/cat.cpp b/src/frontends/pytorch/src/op/cat.cpp index 33d2f5f18a20fd..a200ad3f997acf 100644 --- a/src/frontends/pytorch/src/op/cat.cpp +++ b/src/frontends/pytorch/src/op/cat.cpp @@ -102,7 +102,7 @@ OutputVector translate_quantized_cat(const NodeContext& context) { }; OutputVector translate_stack_fx(const NodeContext& context) { - num_inputs_check(context, 2, context.get_input_size()); + num_inputs_check(context, 1, context.get_input_size()); auto dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); std::deque> list_elems; auto num_elements = context.get_input_size(); @@ -112,14 +112,12 @@ OutputVector translate_stack_fx(const NodeContext& context) { list_elems.push_back(stack_input); } int64_t axis = 0; - if (context.get_input_size() > 2) - axis = context.const_input(context.get_input_size() - 1); - if (!context.get_input_type(context.get_input_size() - 1).is()) { + if (!context.get_input_type(num_elements - 1).is()) { // axis can be not present and that means that last input will have List type - axis = context.const_input(context.get_input_size() - 1); + axis = context.const_input(num_elements - 1); } else { auto stack_input = - context.mark_node(std::make_shared(context.get_input(static_cast(context.get_input_size() - 1)), dim)); + context.mark_node(std::make_shared(context.get_input(static_cast(num_elements - 1)), dim)); list_elems.push_back(stack_input); } return translate_cat_common(context, list_elems, axis, true); From b4dde851a0114a5b7fbf011b422c540b1df1fca4 Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Mon, 4 Mar 2024 20:38:40 -0800 Subject: [PATCH 08/25] TorchFX unit tests for Div and Elu --- tests/layer_tests/pytorch_tests/test_div.py | 1 + tests/layer_tests/pytorch_tests/test_elu.py | 38 +++++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 tests/layer_tests/pytorch_tests/test_elu.py diff --git a/tests/layer_tests/pytorch_tests/test_div.py b/tests/layer_tests/pytorch_tests/test_div.py index a2809b1fa68899..eaccadae7ba81a 100644 --- a/tests/layer_tests/pytorch_tests/test_div.py +++ b/tests/layer_tests/pytorch_tests/test_div.py @@ -44,6 +44,7 @@ def forward(self, input_tensor, other_tensor): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_div_pt_spec(self, input_array, other_array, rounding_mode, ie_device, precision, ir_version): self.input_array = 
input_array self.input_type = np.float32 diff --git a/tests/layer_tests/pytorch_tests/test_elu.py b/tests/layer_tests/pytorch_tests/test_elu.py new file mode 100644 index 00000000000000..eb2dc76215405f --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_elu.py @@ -0,0 +1,38 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestElu(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(2, 3).astype(np.float32),) + + def create_model(self, alpha, inplace): + import torch + import torch.nn.functional as F + + class aten_elu(torch.nn.Module): + def __init__(self, alpha, inplace): + super(aten_elu, self).__init__() + self.alpha = alpha + self.inplace = inplace + + def forward(self, x): + return F.elu(x, alpha=self.alpha, inplace=self.inplace) + + ref_net = None + + return aten_elu(alpha, inplace), ref_net, "aten::elu" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend + @pytest.mark.parametrize("alpha", [1.0, 0.5]) + @pytest.mark.parametrize("inplace", [True, False]) + def test_elu(self, alpha, inplace, ie_device, precision, ir_version): + self._test(*self.create_model(alpha, inplace), ie_device, precision, ir_version) From a363a072c8afbefb874f08db84e8acedee0de380 Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Mon, 4 Mar 2024 21:19:19 -0800 Subject: [PATCH 09/25] TorchFX unit test update: full, comparison, glu --- tests/layer_tests/pytorch_tests/test_comparision.py | 2 ++ tests/layer_tests/pytorch_tests/test_full.py | 1 + tests/layer_tests/pytorch_tests/test_glu.py | 3 ++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/layer_tests/pytorch_tests/test_comparision.py b/tests/layer_tests/pytorch_tests/test_comparision.py index 969079d8e88cf7..79906ef3c95b30 100644 --- a/tests/layer_tests/pytorch_tests/test_comparision.py +++ b/tests/layer_tests/pytorch_tests/test_comparision.py @@ -55,6 +55,7 @@ def forward(self, x, y): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_comp(self, op, ie_device, precision, ir_version): self._test(*self.create_model(op), ie_device, precision, ir_version, use_convert_model=True) @@ -127,6 +128,7 @@ def forward3(self, lhs, rhs): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_eq_mixed_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape, op): self.lhs_type = lhs_type self.lhs_shape = lhs_shape diff --git a/tests/layer_tests/pytorch_tests/test_full.py b/tests/layer_tests/pytorch_tests/test_full.py index cf3794be11e891..3264217c2b658d 100644 --- a/tests/layer_tests/pytorch_tests/test_full.py +++ b/tests/layer_tests/pytorch_tests/test_full.py @@ -84,6 +84,7 @@ def forward(self, x: float): @pytest.mark.parametrize("value", [0, 1, -1, 0.5]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_full(self, shape, value, ie_device, precision, ir_version): self._test(*self.create_model(shape), ie_device, precision, ir_version, kwargs_to_prepare_input={'value': value}) diff --git a/tests/layer_tests/pytorch_tests/test_glu.py b/tests/layer_tests/pytorch_tests/test_glu.py index aa77fb7d77664c..3011755bb8a5f1 100644 --- a/tests/layer_tests/pytorch_tests/test_glu.py +++ 
b/tests/layer_tests/pytorch_tests/test_glu.py @@ -30,6 +30,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dim", [0, 1, 2, 3, -1, -2]) def test_glu(self, dim, ie_device, precision, ir_version): - self._test(*self.create_model(dim), ie_device, precision, ir_version) \ No newline at end of file + self._test(*self.create_model(dim), ie_device, precision, ir_version) From 33da1babaae44d662a963b6c9a8409d63d03d289 Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Mon, 4 Mar 2024 21:56:38 -0800 Subject: [PATCH 10/25] TorchFX unit tests: grip_sample_2d, hardswish, hardtanh --- .../pytorch_tests/test_grid_sampler.py | 1 + .../pytorch_tests/test_hardswish.py | 38 +++++++++++++++++ .../pytorch_tests/test_hardtanh.py | 41 +++++++++++++++++++ 3 files changed, 80 insertions(+) create mode 100644 tests/layer_tests/pytorch_tests/test_hardswish.py create mode 100644 tests/layer_tests/pytorch_tests/test_hardtanh.py diff --git a/tests/layer_tests/pytorch_tests/test_grid_sampler.py b/tests/layer_tests/pytorch_tests/test_grid_sampler.py index 7b55862e2f0c2d..ca858e89a4e410 100644 --- a/tests/layer_tests/pytorch_tests/test_grid_sampler.py +++ b/tests/layer_tests/pytorch_tests/test_grid_sampler.py @@ -37,6 +37,7 @@ def forward(self, input, grid): @pytest.mark.parametrize("align_corners", [True, False, None]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', reason='Ticket - 122715') def test_grid_sampler(self, h_in, w_in, h_out, w_out, mode, padding_mode, align_corners, ie_device, precision, ir_version): diff --git a/tests/layer_tests/pytorch_tests/test_hardswish.py b/tests/layer_tests/pytorch_tests/test_hardswish.py new file mode 100644 index 00000000000000..f054167372cf94 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_hardswish.py @@ -0,0 +1,38 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import platform + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestHardswish(PytorchLayerTest): + def _prepare_input(self): + return (np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4),) + + def create_model(self, inplace): + import torch + import torch.nn.functional as F + + class aten_hardswish(torch.nn.Module): + def __init__(self, inplace): + super(aten_hardswish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return F.hardswish(x, inplace=self.inplace) + + ref_net = None + + return aten_hardswish(inplace), ref_net, "aten::hardswish" + + @pytest.mark.parametrize("inplace", [True, False]) + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.precommit_fx_backend + def test_hardswish(self, inplace, ie_device, precision, ir_version): + self._test(*self.create_model(inplace), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_hardtanh.py b/tests/layer_tests/pytorch_tests/test_hardtanh.py new file mode 100644 index 00000000000000..a1aeb565d9d27f --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_hardtanh.py @@ -0,0 +1,41 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import platform + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +class 
TestHardtanh(PytorchLayerTest): + def _prepare_input(self): + return (np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4),) + + def create_model(self, min_val, max_val, inplace): + import torch + import torch.nn.functional as F + + class aten_hardtanh(torch.nn.Module): + def __init__(self, min_val, max_val, inplace): + super(aten_hardtanh, self).__init__() + self.min_val = min_val + self.max_val = max_val + self.inplace = inplace + + def forward(self, x): + return F.hardtanh(x, min_val=self.min_val, max_val=self.max_val, inplace=self.inplace) + + ref_net = None + + return aten_hardtanh(min_val, max_val, inplace), ref_net, "aten::hardtanh" + + @pytest.mark.parametrize(("min_val", "max_val"), [[-1.0,1.0], [0, 1.0], [-2.0, 2.0]]) + @pytest.mark.parametrize("inplace", [True, False]) + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.precommit_fx_backend + def test_hardtanh(self, min_val, max_val, inplace, ie_device, precision, ir_version): + self._test(*self.create_model(min_val, max_val, inplace), ie_device, precision, ir_version) From a3f919ada951efb0e91249dcc94a81e4a2f6adc0 Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Mon, 4 Mar 2024 22:39:12 -0800 Subject: [PATCH 11/25] TorchFX unit tests: index_select, unary_ops, isinf, isnan --- .../pytorch_tests/test_index_select.py | 1 + tests/layer_tests/pytorch_tests/test_isinf.py | 33 +++++++++++++++++++ tests/layer_tests/pytorch_tests/test_isnan.py | 33 +++++++++++++++++++ .../pytorch_tests/test_unary_ops.py | 1 + 4 files changed, 68 insertions(+) create mode 100644 tests/layer_tests/pytorch_tests/test_isinf.py create mode 100644 tests/layer_tests/pytorch_tests/test_isnan.py diff --git a/tests/layer_tests/pytorch_tests/test_index_select.py b/tests/layer_tests/pytorch_tests/test_index_select.py index 1c8c2c91eb331c..1ba5d6c569cd6d 100644 --- a/tests/layer_tests/pytorch_tests/test_index_select.py +++ b/tests/layer_tests/pytorch_tests/test_index_select.py @@ -41,6 +41,7 @@ def forward_out(self, x, indices, out): @pytest.mark.parametrize("out", [False, True]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_index_select(self, dim, out, indices, ie_device, precision, ir_version): self._test(*self.create_model(dim, out), ie_device, precision, ir_version, kwargs_to_prepare_input={"index": indices, "out": out, "dim": dim}) diff --git a/tests/layer_tests/pytorch_tests/test_isinf.py b/tests/layer_tests/pytorch_tests/test_isinf.py new file mode 100644 index 00000000000000..70be61a5070d9f --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_isinf.py @@ -0,0 +1,33 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +@pytest.mark.parametrize('input_tensor', (torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))) +class TestIsInf(PytorchLayerTest): + + def _prepare_input(self): + input_tensor = self.input_tensor + return (input_tensor,) + + def create_model(self): + class aten_isinf(torch.nn.Module): + + def forward(self, input_tensor): + return torch.isinf(input_tensor) + + ref_net = None + + return aten_isinf(), ref_net, "aten::isinf" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.precommit_fx_backend + def test_isinf(self, ie_device, precision, ir_version, input_tensor): + self.input_tensor = input_tensor + self._test(*self.create_model(), ie_device, precision, ir_version) diff --git 
a/tests/layer_tests/pytorch_tests/test_isnan.py b/tests/layer_tests/pytorch_tests/test_isnan.py new file mode 100644 index 00000000000000..b2b5325ba11107 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_isnan.py @@ -0,0 +1,33 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +@pytest.mark.parametrize('input_tensor', (torch.tensor([1, float('nan'), 2]))) +class TestIsNan(PytorchLayerTest): + + def _prepare_input(self): + input_tensor = self.input_tensor + return (input_tensor,) + + def create_model(self): + class aten_isnan(torch.nn.Module): + + def forward(self, input_tensor): + return torch.isnan(input_tensor) + + ref_net = None + + return aten_isnan(), ref_net, "aten::isnan" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.precommit_fx_backend + def test_isnan(self, ie_device, precision, ir_version, input_tensor): + self.input_tensor = input_tensor + self._test(*self.create_model(), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_unary_ops.py b/tests/layer_tests/pytorch_tests/test_unary_ops.py index f495e7ba3d272f..1bda4ba4ba2e49 100644 --- a/tests/layer_tests/pytorch_tests/test_unary_ops.py +++ b/tests/layer_tests/pytorch_tests/test_unary_ops.py @@ -71,6 +71,7 @@ def _prepare_input(self): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", [torch.float32, torch.float64, torch.int8, torch.uint8, torch.int32, torch.int64]) @pytest.mark.parametrize("op_type", [ From 4830b785548d9e8d74f7b8d38fa0e46a7307726a Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Tue, 5 Mar 2024 16:32:29 -0800 Subject: [PATCH 12/25] TorchFX: Additional op unit tests --- .../pytorch_tests/test_batch_norm.py | 2 +- .../test_deformable_convolution.py | 1 + .../pytorch_tests/test_fake_quantize.py | 55 +++++++++++++++++++ tests/layer_tests/pytorch_tests/test_flip.py | 3 +- tests/layer_tests/pytorch_tests/test_full.py | 4 ++ .../pytorch_tests/test_leaky_relu.py | 1 + tests/layer_tests/pytorch_tests/test_log.py | 3 +- .../pytorch_tests/test_logical_ops.py | 3 +- .../pytorch_tests/test_masked_fill.py | 2 +- tests/layer_tests/pytorch_tests/test_mean.py | 1 + .../layer_tests/pytorch_tests/test_min_max.py | 5 +- .../layer_tests/pytorch_tests/test_pooling.py | 2 + tests/layer_tests/pytorch_tests/test_pow.py | 2 + .../layer_tests/pytorch_tests/test_repeat.py | 3 + .../pytorch_tests/test_roi_align.py | 1 + tests/layer_tests/pytorch_tests/test_roll.py | 7 ++- tests/layer_tests/pytorch_tests/test_rsub.py | 3 +- .../layer_tests/pytorch_tests/test_scatter.py | 1 + tests/layer_tests/pytorch_tests/test_sign.py | 1 + tests/layer_tests/pytorch_tests/test_sort.py | 1 + tests/layer_tests/pytorch_tests/test_split.py | 3 + tests/layer_tests/pytorch_tests/test_topk.py | 1 + tests/layer_tests/pytorch_tests/test_trilu.py | 4 +- .../layer_tests/pytorch_tests/test_unfold.py | 1 + .../pytorch_tests/test_var_mean.py | 4 +- 25 files changed, 102 insertions(+), 12 deletions(-) diff --git a/tests/layer_tests/pytorch_tests/test_batch_norm.py b/tests/layer_tests/pytorch_tests/test_batch_norm.py index 2275f53960ae45..81157bfe1e0557 100644 --- a/tests/layer_tests/pytorch_tests/test_batch_norm.py +++ b/tests/layer_tests/pytorch_tests/test_batch_norm.py @@ -62,4 +62,4 @@ def forward(self, x): @pytest.mark.precommit_torch_export def test_batch_norm(self, 
weights, bias, eps, train, running_stats, ie_device, precision, ir_version, kwargs_to_prepare_input): self._test(*self.create_model(weights, bias, eps, train, running_stats), - ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, dynamic_shapes=False, use_mo_convert=False) \ No newline at end of file + ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, dynamic_shapes=False, use_mo_convert=False) diff --git a/tests/layer_tests/pytorch_tests/test_deformable_convolution.py b/tests/layer_tests/pytorch_tests/test_deformable_convolution.py index fa4293b275e6c7..2986cf89a4baa2 100644 --- a/tests/layer_tests/pytorch_tests/test_deformable_convolution.py +++ b/tests/layer_tests/pytorch_tests/test_deformable_convolution.py @@ -170,6 +170,7 @@ def forward(self, x): @pytest.mark.parametrize("mask", [True, False]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_deformable_convolution2d(self, params, bias, mask, ie_device, precision, ir_version): self._test( *self.create_model(**params, bias=bias, mask=mask), ie_device, precision, ir_version, trace_model=True diff --git a/tests/layer_tests/pytorch_tests/test_fake_quantize.py b/tests/layer_tests/pytorch_tests/test_fake_quantize.py index 3146ac87b90087..ba8cb3eb76e1d2 100644 --- a/tests/layer_tests/pytorch_tests/test_fake_quantize.py +++ b/tests/layer_tests/pytorch_tests/test_fake_quantize.py @@ -37,6 +37,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize( "scale, zero_point, quant_min, quant_max", [ @@ -61,6 +62,60 @@ def test_fake_quantize_per_tensor_affine( freeze_model=False ) +class TestFakeQuantizePerTensorAffineCacheMaskTensorQParams(PytorchLayerTest): + def _prepare_input(self): + return (np.random.randn(3, 2, 2).astype(np.float32),) + + def create_model(self, scale, zero_point, quant_min, quant_max): + class _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(torch.nn.Module): + def __init__(self, scale, zero_point, quant_min, quant_max): + super(_fake_quantize_per_tensor_affine_cachemask_tensor_qparams, self).__init__() + self.scale = torch.tensor(scale) + self.zero_point = torch.tensor(zero_point) + self.fake_quant_enabled = torch.tensor(1) + self.quant_min = quant_min + self.quant_max = quant_max + + def forward(self, x): + return torch._fake_quantize_per_tensor_affine_cachemask_tensor_qparams( + x, self.scale, self.zero_point, self.fake_quant_enabled, self.quant_min, self.quant_max + ) + + ref_net = None + + return ( + _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(scale, zero_point, quant_min, quant_max), + ref_net, + "aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams", + ) + + @pytest.mark.nightly + @pytest.mark.precommit + #@pytest.mark.precommit_fx_backend + @pytest.mark.parametrize( + "scale, zero_point, quant_min, quant_max", + [ + (1.0, 1, 0, 255), + (0.01, 0, 0, 255), + (-0.01, 0, 0, 255), + (0.5, 0, -128, 127), + (0.5, -1, -128, 127), + (1.0, 0, 0, 127), + ], + ) + @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', + reason='Ticket - 122715') + def test__fake_quantize_per_tensor_affine_cachemask_tensor_qparams( + self, ie_device, precision, ir_version, scale, zero_point, quant_min, quant_max + ): + self._test( + *self.create_model(scale, zero_point, quant_min, quant_max), + ie_device, + precision, + ir_version, + freeze_model=False + ) + class 
TestFakeQuantizePerChannelAffine(PytorchLayerTest): def _prepare_input(self): diff --git a/tests/layer_tests/pytorch_tests/test_flip.py b/tests/layer_tests/pytorch_tests/test_flip.py index df390eb7caf001..4cc5f2a57cb701 100644 --- a/tests/layer_tests/pytorch_tests/test_flip.py +++ b/tests/layer_tests/pytorch_tests/test_flip.py @@ -37,8 +37,9 @@ def forward_out(self, x, y): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("axis", [[0], [1], [-1], [1, 2], [2, 3], [1, 2, 3]]) @pytest.mark.parametrize("out", [skip_if_export(True), False]) @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "uint8"]) def test_flip(self, axis, out, dtype, ie_device, precision, ir_version): - self._test(*self.create_model(axis, out), ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "dtype": dtype}) \ No newline at end of file + self._test(*self.create_model(axis, out), ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "dtype": dtype}) diff --git a/tests/layer_tests/pytorch_tests/test_full.py b/tests/layer_tests/pytorch_tests/test_full.py index 3264217c2b658d..5253c15f75e980 100644 --- a/tests/layer_tests/pytorch_tests/test_full.py +++ b/tests/layer_tests/pytorch_tests/test_full.py @@ -94,6 +94,7 @@ def test_full(self, shape, value, ie_device, precision, ir_version): @pytest.mark.parametrize("dtype", ["int8", "int32", "int64", "float32", "float64"]) @pytest.mark.parametrize("with_names", [True, False]) @pytest.mark.nightly + @pytest.mark.precommit_fx_backend def test_full_dtype(self, shape, value, dtype, with_names, ie_device, precision, ir_version): self._test(*self.create_model(shape, dtype=dtype, use_dtype=True, with_names=with_names), ie_device, precision, ir_version, kwargs_to_prepare_input={'value': value}) @@ -279,6 +280,7 @@ def forward(self, input_t: torch.Tensor, x: float): @pytest.mark.parametrize("value", [0, 1, -1, 0.5]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_full_like(self, shape, value, ie_device, precision, ir_version): self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={'value': value, 'shape': shape}) @@ -346,6 +348,7 @@ def forward(self, input_tensor: torch.Tensor, x: float): @pytest.mark.parametrize("value,input_dtype", [(0, np.uint8), (1, np.int32), (-1, np.float32), (0.5, np.float64)]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_new_full(self, shape, value, input_dtype, ie_device, precision, ir_version): self._test(*self.create_model(shape), ie_device, precision, ir_version, kwargs_to_prepare_input={'value': value, 'input_dtype': input_dtype}, use_convert_model=True) @@ -475,6 +478,7 @@ def forward(self, x): @pytest.mark.parametrize("op_type", ["aten::zeros", "aten::ones", "aten::zeros_like", "aten::ones_like"]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_fill(self, op_type, shape, ie_device, precision, ir_version): self._test(*self.create_model(op_type), ie_device, precision, ir_version, kwargs_to_prepare_input={'shape': shape}) diff --git a/tests/layer_tests/pytorch_tests/test_leaky_relu.py b/tests/layer_tests/pytorch_tests/test_leaky_relu.py index 2ef80dd388ae89..ee390548d4099e 100644 --- a/tests/layer_tests/pytorch_tests/test_leaky_relu.py +++ b/tests/layer_tests/pytorch_tests/test_leaky_relu.py @@ -32,5 +32,6 @@ def forward(self, x): 
@pytest.mark.parametrize("inplace", [skip_if_export(True), False]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_leaky_relu(self, alpha, inplace, ie_device, precision, ir_version): self._test(*self.create_model(alpha, inplace), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_log.py b/tests/layer_tests/pytorch_tests/test_log.py index 8d595e82e82166..4cab2bf8b0460b 100644 --- a/tests/layer_tests/pytorch_tests/test_log.py +++ b/tests/layer_tests/pytorch_tests/test_log.py @@ -40,6 +40,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize(("op", "input_dtype"), [["log", "float32"], ["log", "int32"], @@ -57,4 +58,4 @@ def test_log(self, op, input_dtype, ie_device, precision, ir_version): if PytorchLayerTest.use_torch_export() and op[-1] == "_": pytest.skip(reason="export fails for inplace") self._test(*self.create_model(op), ie_device, precision, - ir_version, kwargs_to_prepare_input={"dtype": input_dtype}) \ No newline at end of file + ir_version, kwargs_to_prepare_input={"dtype": input_dtype}) diff --git a/tests/layer_tests/pytorch_tests/test_logical_ops.py b/tests/layer_tests/pytorch_tests/test_logical_ops.py index 210fd1a4bdb690..842d895542afb9 100644 --- a/tests/layer_tests/pytorch_tests/test_logical_ops.py +++ b/tests/layer_tests/pytorch_tests/test_logical_ops.py @@ -53,6 +53,7 @@ def forward_not_out(self, tensor_a, out): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("op_type", ["and", "or", "not", "xor"]) @pytest.mark.parametrize("first_dtype", ["bool", "int32", 'int8', 'float32']) @pytest.mark.parametrize("second_dtype", ["bool", "int32", 'int8', 'float32']) @@ -61,4 +62,4 @@ def test_logical(self, op_type, out, first_dtype, second_dtype, ie_device, preci self._test(*self.create_model(op_type, out), ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "unary": op_type == "not", - "first_dtype": first_dtype, "second_dtype": second_dtype}) \ No newline at end of file + "first_dtype": first_dtype, "second_dtype": second_dtype}) diff --git a/tests/layer_tests/pytorch_tests/test_masked_fill.py b/tests/layer_tests/pytorch_tests/test_masked_fill.py index 0f934843b077e8..8f2f109ec8021e 100644 --- a/tests/layer_tests/pytorch_tests/test_masked_fill.py +++ b/tests/layer_tests/pytorch_tests/test_masked_fill.py @@ -54,6 +54,7 @@ def forward(self, x, mask): @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_masked_fill(self, value, mask_fill, mask_dtype, input_dtype, inplace, ie_device, precision, ir_version): self._test(*self.create_model(value, inplace), ie_device, precision, ir_version, @@ -67,7 +68,6 @@ def test_masked_fill(self, value, mask_fill, mask_dtype, input_dtype, inplace, i @pytest.mark.parametrize("mask_dtype", [np.uint8, np.int32]) # np.float32 incorrectly casted to bool @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.nightly - @pytest.mark.precommit def test_masked_fill_non_bool_mask(self, value, mask_fill, mask_dtype, input_dtype, inplace, ie_device, precision, ir_version): self._test(*self.create_model(value, inplace), ie_device, precision, ir_version, diff --git a/tests/layer_tests/pytorch_tests/test_mean.py b/tests/layer_tests/pytorch_tests/test_mean.py index af381fa19bb7d3..28297e449c93cc 100644 --- a/tests/layer_tests/pytorch_tests/test_mean.py 
+++ b/tests/layer_tests/pytorch_tests/test_mean.py @@ -80,6 +80,7 @@ def forward_out(self, x, out): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_sum(self, axes, keep_dim, dtype, out, ie_device, precision, ir_version): if PytorchLayerTest.use_torch_export() and out: pytest.skip(reason="export fails for out") diff --git a/tests/layer_tests/pytorch_tests/test_min_max.py b/tests/layer_tests/pytorch_tests/test_min_max.py index 8008d725db3d63..27cb46ffb7a425 100644 --- a/tests/layer_tests/pytorch_tests/test_min_max.py +++ b/tests/layer_tests/pytorch_tests/test_min_max.py @@ -76,6 +76,7 @@ def forward(self, x, y): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_reduce_min_max(self, axes, keep_dims, op_type, ie_device, precision, ir_version): self._test(*self.create_model(op_type, axes, keep_dims, single_input=True), ie_device, precision, ir_version) @@ -86,6 +87,7 @@ def test_reduce_min_max(self, axes, keep_dims, op_type, ie_device, precision, ir @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_min_max(self, op_type, first_input_dtype, second_input_dtype, ie_device, precision, ir_version): self._test(*self.create_model(op_type, None, None, single_input=False, dtypes=(first_input_dtype, second_input_dtype)), ie_device, precision, ir_version, kwargs_to_prepare_input= @@ -266,6 +268,7 @@ def forward(self, x, y): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_minimum_maximum( self, op_type, first_input_dtype, second_input_dtype, ie_device, precision, ir_version ): @@ -342,4 +345,4 @@ def test_amin_amax(self, op_type, input_dtype, axis, keep_dims, out, ie_device, self._test(*self.create_model(op_type, axis, keep_dims, out), ie_device, precision, ir_version, kwargs_to_prepare_input= {"input_dtype": input_dtype, "out": out, "axes": axis, "keep_dims": keep_dims} - ) \ No newline at end of file + ) diff --git a/tests/layer_tests/pytorch_tests/test_pooling.py b/tests/layer_tests/pytorch_tests/test_pooling.py index f0ea018a552856..51a21af520f38d 100644 --- a/tests/layer_tests/pytorch_tests/test_pooling.py +++ b/tests/layer_tests/pytorch_tests/test_pooling.py @@ -232,6 +232,7 @@ def test_max_pool1d_indices(self, params, ceil_mode, dilation, ie_device, precis @pytest.mark.parametrize("dilation", [1, 2]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', reason='Ticket - 122715') def test_max_pool2d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version): @@ -248,6 +249,7 @@ def test_max_pool2d_indices(self, params, ceil_mode, dilation, ie_device, preci @pytest.mark.parametrize("dilation", [1, 2]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', reason='Ticket - 122715') def test_max_pool3d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version): diff --git a/tests/layer_tests/pytorch_tests/test_pow.py b/tests/layer_tests/pytorch_tests/test_pow.py index b3424dfc3be695..c684d74fbb5ea8 100644 --- a/tests/layer_tests/pytorch_tests/test_pow.py +++ b/tests/layer_tests/pytorch_tests/test_pow.py @@ -47,6 +47,7 @@ def 
forward_inplace(self, input_data, exponent): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_pow(self, inplace, ie_device, precision, ir_version, test_input): if inplace and PytorchLayerTest.use_torch_export(): pytest.skip(reason="export fails for inplace") @@ -109,6 +110,7 @@ def forward3(self, lhs, rhs): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_pow_mixed_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape): self.lhs_type = lhs_type self.lhs_shape = lhs_shape diff --git a/tests/layer_tests/pytorch_tests/test_repeat.py b/tests/layer_tests/pytorch_tests/test_repeat.py index bc7949eb091c30..6f6c952ebc9a9e 100644 --- a/tests/layer_tests/pytorch_tests/test_repeat.py +++ b/tests/layer_tests/pytorch_tests/test_repeat.py @@ -30,6 +30,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_repeat(self, repeats, ie_device, precision, ir_version): self._test(*self.create_model(repeats), ie_device, precision, ir_version) @@ -56,6 +57,7 @@ def forward(self, x, y): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_repeat(self, repeats, ie_device, precision, ir_version): self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"repeats_shape": repeats}) @@ -79,5 +81,6 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_repeat_t5(self, ie_device, precision, ir_version): self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, use_convert_model=True) diff --git a/tests/layer_tests/pytorch_tests/test_roi_align.py b/tests/layer_tests/pytorch_tests/test_roi_align.py index 574741aaa26db0..63ace627463a67 100644 --- a/tests/layer_tests/pytorch_tests/test_roi_align.py +++ b/tests/layer_tests/pytorch_tests/test_roi_align.py @@ -52,6 +52,7 @@ def forward(self, input_tensor, rois): @pytest.mark.parametrize('aligned', (True, False)) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_roi_align(self, ie_device, precision, ir_version, input_shape, boxes, output_size, spatial_scale, sampling_ratio, aligned): self.input_tensor = np.random.randn(*input_shape).astype(np.float32) diff --git a/tests/layer_tests/pytorch_tests/test_roll.py b/tests/layer_tests/pytorch_tests/test_roll.py index 7c4c6f2831717a..27ed2044ae7bac 100644 --- a/tests/layer_tests/pytorch_tests/test_roll.py +++ b/tests/layer_tests/pytorch_tests/test_roll.py @@ -18,12 +18,12 @@ class aten_roll(torch.nn.Module): def __init__(self, shifts, dim=None): super(aten_roll, self).__init__() self.dim = dim - self.shits = shifts + self.shifts = shifts def forward(self, x): if self.dim is not None: - return torch.roll(x, self.shits, self.dim) - return torch.roll(x, self.shits) + return torch.roll(x, self.shifts, self.dim) + return torch.roll(x, self.shifts) ref_net = None @@ -38,5 +38,6 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_roll(self, shifts, dim, ie_device, precision, ir_version): self._test(*self.create_model(shifts, dim), ie_device, precision, ir_version) diff --git 
a/tests/layer_tests/pytorch_tests/test_rsub.py b/tests/layer_tests/pytorch_tests/test_rsub.py index 08a9372582b07d..306ee34b0f13d8 100644 --- a/tests/layer_tests/pytorch_tests/test_rsub.py +++ b/tests/layer_tests/pytorch_tests/test_rsub.py @@ -104,9 +104,10 @@ def forward2(self, lhs, rhs:int): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_rsub_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type): self.lhs_type = lhs_type self.lhs_shape = lhs_shape self.rhs_type = rhs_type self._test(*self.create_model(lhs_type, rhs_type), - ie_device, precision, ir_version) \ No newline at end of file + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_scatter.py b/tests/layer_tests/pytorch_tests/test_scatter.py index 620d6d5c0ed0c3..5ed7d4452d4c17 100644 --- a/tests/layer_tests/pytorch_tests/test_scatter.py +++ b/tests/layer_tests/pytorch_tests/test_scatter.py @@ -91,6 +91,7 @@ def _forward_inplace_reduce(self, x: torch.Tensor): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dim", [1, -1, 0]) @pytest.mark.parametrize( "index", diff --git a/tests/layer_tests/pytorch_tests/test_sign.py b/tests/layer_tests/pytorch_tests/test_sign.py index bcb3ba2b27443b..7e3ae4a360c4cc 100644 --- a/tests/layer_tests/pytorch_tests/test_sign.py +++ b/tests/layer_tests/pytorch_tests/test_sign.py @@ -45,6 +45,7 @@ def forward_out(self, x, out): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("input_type", ["zeros", "positive", "negative", "mixed"]) @pytest.mark.parametrize("out", [True, False]) def test_sign(self, input_type, out, ie_device, precision, ir_version): diff --git a/tests/layer_tests/pytorch_tests/test_sort.py b/tests/layer_tests/pytorch_tests/test_sort.py index 28ff2b7d485e56..6dc629b2068660 100644 --- a/tests/layer_tests/pytorch_tests/test_sort.py +++ b/tests/layer_tests/pytorch_tests/test_sort.py @@ -78,6 +78,7 @@ def forward(self, input_tensor): ]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_sort(self, input_shape, descending, stable, ie_device, precision, ir_version): self.input_tensor = [] if type(input_shape) is list: diff --git a/tests/layer_tests/pytorch_tests/test_split.py b/tests/layer_tests/pytorch_tests/test_split.py index 8d03760260c312..de9fe551fd68ca 100644 --- a/tests/layer_tests/pytorch_tests/test_split.py +++ b/tests/layer_tests/pytorch_tests/test_split.py @@ -61,6 +61,7 @@ def forward(self, input): @pytest.mark.parametrize("getitem", [-5, -2, -1, 0, 1, 4]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_split_getitem(self, params, getitem, ie_device, precision, ir_version): (self.split_param, self.axis) = params self.getitem = getitem @@ -70,6 +71,7 @@ def test_split_getitem(self, params, getitem, ie_device, precision, ir_version): @pytest.mark.parametrize("params", test_cases) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_split_listunpack(self, params, ie_device, precision, ir_version): (self.split_param, self.axis) = params self._test( @@ -99,6 +101,7 @@ def forward(self, x, y): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_split_with_sizes(self, ie_device, precision, ir_version): self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True) diff --git 
a/tests/layer_tests/pytorch_tests/test_topk.py b/tests/layer_tests/pytorch_tests/test_topk.py index 1b657f25ade1a5..c88b413bb46ec9 100644 --- a/tests/layer_tests/pytorch_tests/test_topk.py +++ b/tests/layer_tests/pytorch_tests/test_topk.py @@ -61,6 +61,7 @@ def forward(self, input_tensor): ]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") == 'true', reason="Ticket - 115085") def test_topK(self, input_shape, k, dim, largest, sort, ie_device, precision, ir_version): self.input_tensor = np.random.randn(*input_shape).astype(np.float32) diff --git a/tests/layer_tests/pytorch_tests/test_trilu.py b/tests/layer_tests/pytorch_tests/test_trilu.py index 28842e101ce6da..dbafbc276af180 100644 --- a/tests/layer_tests/pytorch_tests/test_trilu.py +++ b/tests/layer_tests/pytorch_tests/test_trilu.py @@ -41,6 +41,7 @@ def forward(self, x): @pytest.mark.parametrize("op", ["triu", "tril"]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version): self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version, kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype}) @@ -89,6 +90,7 @@ def triu_(self, x): @pytest.mark.parametrize("op", ["triu", "tril", "triu_", "tril_"]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version): self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version, - kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype}) \ No newline at end of file + kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype}) diff --git a/tests/layer_tests/pytorch_tests/test_unfold.py b/tests/layer_tests/pytorch_tests/test_unfold.py index 37a0e467544e55..efc092e5f51f3c 100644 --- a/tests/layer_tests/pytorch_tests/test_unfold.py +++ b/tests/layer_tests/pytorch_tests/test_unfold.py @@ -39,6 +39,7 @@ def forward(self, input_tensor): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_unfold(self, ie_device, precision, ir_version, dimension, size, step, input_shape): self.input_tensor = np.random.randn(*input_shape).astype(np.float32) self._test(*self.create_model(dimension, size, step), diff --git a/tests/layer_tests/pytorch_tests/test_var_mean.py b/tests/layer_tests/pytorch_tests/test_var_mean.py index fddf7457749096..d62d846f63dc41 100644 --- a/tests/layer_tests/pytorch_tests/test_var_mean.py +++ b/tests/layer_tests/pytorch_tests/test_var_mean.py @@ -52,6 +52,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("unbiased", [True, False]) @pytest.mark.parametrize("op_type", ["var", "var_mean", "std", "std_mean"]) @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', @@ -61,6 +62,7 @@ def test_op2args(self, unbiased, op_type, ie_device, precision, ir_version): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("unbiased", [False, True]) @pytest.mark.parametrize("dim", [None, 0, 1, 2, 3, -1, -2, (0, 1), (-1, -2), (0, 1, -1), (0, 1, 2, 3)]) @pytest.mark.parametrize("keepdim", [True, False]) @@ -68,4 +70,4 @@ def test_op2args(self, unbiased, op_type, ie_device, precision, ir_version): 
@pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', reason='Ticket - 122715') def test_op(self, unbiased, dim, keepdim, op_type, ie_device, precision, ir_version): - self._test(*self.create_model(unbiased, dim, keepdim, two_args_case=False, op_type=op_type), ie_device, precision, ir_version) \ No newline at end of file + self._test(*self.create_model(unbiased, dim, keepdim, two_args_case=False, op_type=op_type), ie_device, precision, ir_version) From 3f95242a21ccf9934efe57d23a1cea31bf1eb615 Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Wed, 6 Mar 2024 00:03:27 -0800 Subject: [PATCH 13/25] TorchFX: Additonal unit tests --- .../layer_tests/pytorch_tests/test_addcmul.py | 1 + tests/layer_tests/pytorch_tests/test_all.py | 2 + tests/layer_tests/pytorch_tests/test_any.py | 37 +++++++++++++++++ .../layer_tests/pytorch_tests/test_arange.py | 6 +++ .../pytorch_tests/test_argmax_argmin.py | 1 + .../pytorch_tests/test_as_strided.py | 3 ++ .../pytorch_tests/test_bitwise_ops.py | 2 + tests/layer_tests/pytorch_tests/test_clamp.py | 3 ++ .../pytorch_tests/test_constant_pad_nd.py | 37 +++++++++++++++++ tests/layer_tests/pytorch_tests/test_copy.py | 3 +- .../pytorch_tests/test_fake_quantize.py | 3 +- tests/layer_tests/pytorch_tests/test_flip.py | 1 - .../layer_tests/pytorch_tests/test_pooling.py | 2 + .../pytorch_tests/test_select_scatter.py | 39 ++++++++++++++++++ .../pytorch_tests/test_slice_scatter.py | 41 +++++++++++++++++++ 15 files changed, 178 insertions(+), 3 deletions(-) create mode 100644 tests/layer_tests/pytorch_tests/test_any.py create mode 100644 tests/layer_tests/pytorch_tests/test_constant_pad_nd.py create mode 100644 tests/layer_tests/pytorch_tests/test_select_scatter.py create mode 100644 tests/layer_tests/pytorch_tests/test_slice_scatter.py diff --git a/tests/layer_tests/pytorch_tests/test_addcmul.py b/tests/layer_tests/pytorch_tests/test_addcmul.py index 5ac9aa51b5f4a7..8861ac7ee9a099 100644 --- a/tests/layer_tests/pytorch_tests/test_addcmul.py +++ b/tests/layer_tests/pytorch_tests/test_addcmul.py @@ -47,6 +47,7 @@ def forward(self, x, y, z): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_addcmul(self, input_type, value, ie_device, precision, ir_version): self.input_type = input_type self._test(*self.create_model(value), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_all.py b/tests/layer_tests/pytorch_tests/test_all.py index d8b90c97d356c2..36c12b74cf548d 100644 --- a/tests/layer_tests/pytorch_tests/test_all.py +++ b/tests/layer_tests/pytorch_tests/test_all.py @@ -77,6 +77,7 @@ def _prepare_input(self, out=False): @pytest.mark.parametrize("out", [True, False]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_all_noparams(self, input_shape, d_type, out, ie_device, precision, ir_version): if type(input_shape) is list: self.input_tensor = np.random.randint(0, 2, input_shape, dtype=d_type) @@ -104,6 +105,7 @@ def test_all_noparams(self, input_shape, d_type, out, ie_device, precision, ir_v @pytest.mark.parametrize("out", [True, False]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() in ('Darwin', 'Linux') and platform.machine() in ('arm', 'armv7l', 'aarch64', 'arm64', 'ARM64'), diff --git a/tests/layer_tests/pytorch_tests/test_any.py b/tests/layer_tests/pytorch_tests/test_any.py new file mode 100644 index 
00000000000000..387d89778a939a --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_any.py @@ -0,0 +1,37 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestAny(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return ((np.random.randint(2, size=(3,3,10,10)) > 0),) + + def create_model(self, dim, keep_dim): + + import torch + class aten_any(torch.nn.Module): + def __init__(self, dim=None, keep_dim=None): + super(aten_any, self).__init__() + self.dim = dim + self.keep_dim = keep_dim + + def forward(self, x): + return torch.any(x, dim=self.dim, keepdim=self.keep_dim) + + + ref_net = None + + return aten_any(dim, keep_dim), ref_net, "aten::any" + + @pytest.mark.parametrize(("dim", "keep_dim"), + [(0, False), (0, True), (-1, True)]) + + @pytest.mark.precommit_fx_backend + def test_any(self, dim, keep_dim, ie_device, precision, ir_version): + self._test(*self.create_model(dim, keep_dim), + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_arange.py b/tests/layer_tests/pytorch_tests/test_arange.py index 9374b140fe893b..91fda699d5a94f 100644 --- a/tests/layer_tests/pytorch_tests/test_arange.py +++ b/tests/layer_tests/pytorch_tests/test_arange.py @@ -109,6 +109,7 @@ def forward(self, x, y, z, d): @pytest.mark.nightly @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8", "uin8"]) @pytest.mark.parametrize("end", [1, 2, 3]) @pytest.mark.parametrize("use_out", [skip_if_export(True), False]) @@ -117,6 +118,7 @@ def test_arange_end_only(self, dtype, end, use_out, ie_device, precision, ir_ver kwargs_to_prepare_input={"end": end}) @pytest.mark.nightly + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8"]) @pytest.mark.parametrize("start,end", [(0, 1), (-1, 1), (1, 5), (0.5, 2.5)]) def test_arange_start_end(self, dtype, end, start, ie_device, precision, ir_version): @@ -125,6 +127,7 @@ def test_arange_start_end(self, dtype, end, start, ie_device, precision, ir_vers @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8"]) @pytest.mark.parametrize("start,end,step", [(0, 1, 1), (-2, 1, 1.25), (1, -5, -1), (1, 10, 2), (-1, -5, -2)]) def test_arange_start_end_step(self, dtype, end, start, step, ie_device, precision, ir_version): @@ -133,6 +136,7 @@ def test_arange_start_end_step(self, dtype, end, start, step, ie_device, precisi @pytest.mark.nightly @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8", "uint8"]) @pytest.mark.parametrize("end", [1, 2, 3]) def test_arange_end_only_with_prim_dtype(self, dtype, end, ie_device, precision, ir_version): @@ -140,6 +144,7 @@ def test_arange_end_only_with_prim_dtype(self, dtype, end, ie_device, precision, kwargs_to_prepare_input={"end": end, "ref_dtype": dtype}) @pytest.mark.nightly + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8"]) @pytest.mark.parametrize("start,end", [(0, 1), (-1, 1), (1, 5), (0.5, 2.5)]) def test_arange_start_end_with_prim_dtype(self, dtype, end, start, ie_device, precision, ir_version): @@ -148,6 +153,7 @@ def 
test_arange_start_end_with_prim_dtype(self, dtype, end, start, ie_device, pr @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8"]) @pytest.mark.parametrize("start,end,step", [(0, 1, 1), (-2, 1, 1.25), (1, -5, -1), (1, 10, 2), (-1, -5, -2)]) def test_arange_start_end_step_with_prim_dtype(self, dtype, end, start, step, ie_device, precision, ir_version): diff --git a/tests/layer_tests/pytorch_tests/test_argmax_argmin.py b/tests/layer_tests/pytorch_tests/test_argmax_argmin.py index 3b7ba0486a4d1e..3a8a61befba0df 100644 --- a/tests/layer_tests/pytorch_tests/test_argmax_argmin.py +++ b/tests/layer_tests/pytorch_tests/test_argmax_argmin.py @@ -74,6 +74,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() in ('Darwin', 'Linux') and platform.machine() in ('arm', 'armv7l', 'aarch64', 'arm64', 'ARM64'), diff --git a/tests/layer_tests/pytorch_tests/test_as_strided.py b/tests/layer_tests/pytorch_tests/test_as_strided.py index fbacb1c81adf61..d9c8fe3cbb723a 100644 --- a/tests/layer_tests/pytorch_tests/test_as_strided.py +++ b/tests/layer_tests/pytorch_tests/test_as_strided.py @@ -41,6 +41,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_as_strided(self, size, stride, offset, ie_device, precision, ir_version): self._test(*self.create_model(size, stride, offset), ie_device, precision, ir_version, trace_model=True) @@ -92,6 +93,7 @@ def forward_size_const(self, x, size_shape_tensor, stride_shape_tensor): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_as_strided_list_construct(self, size, stride, offset, mode, ie_device, precision, ir_version): inp_kwargs = {"size_shape_tensor": size, "stride_shape_tensor": stride} self._test( @@ -124,5 +126,6 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_as_strided_lf(self, ie_device, precision, ir_version): self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, freeze_model=False) diff --git a/tests/layer_tests/pytorch_tests/test_bitwise_ops.py b/tests/layer_tests/pytorch_tests/test_bitwise_ops.py index 5d2b040b33bdc9..55626c51ea2718 100644 --- a/tests/layer_tests/pytorch_tests/test_bitwise_ops.py +++ b/tests/layer_tests/pytorch_tests/test_bitwise_ops.py @@ -55,6 +55,7 @@ def forward_not_out(self, tensor_a, out): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("op_type", ["and", "or", "not", "xor"]) @pytest.mark.parametrize("lhs_dtype", ["bool", "int32", "uint8", "int64"]) @pytest.mark.parametrize("rhs_dtype", ["bool", "int32", "uint8", "int64"]) @@ -107,6 +108,7 @@ def forward(self, lhs, rhs): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("lhs_dtype", ["bool", "int32"]) @pytest.mark.parametrize("rhs_dtype", ["bool", "int32"]) @pytest.mark.parametrize( diff --git a/tests/layer_tests/pytorch_tests/test_clamp.py b/tests/layer_tests/pytorch_tests/test_clamp.py index c98489034d2cae..3c9430d1a56bbe 100644 --- 
a/tests/layer_tests/pytorch_tests/test_clamp.py +++ b/tests/layer_tests/pytorch_tests/test_clamp.py @@ -48,6 +48,7 @@ def forward_clip_(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_clamp(self, minimum, maximum, as_tensors, op_type, ie_device, precision, ir_version): self._test(*self.create_model(minimum, maximum, as_tensors, op_type), ie_device, precision, ir_version) @@ -76,6 +77,7 @@ def forward(self, x): @pytest.mark.parametrize("minimum", [0., 1., -1., 0.5, 2]) @pytest.mark.parametrize("as_tensor", [True, False]) @pytest.mark.nightly + @pytest.mark.precommit_fx_backend def test_clamp_min(self, minimum, as_tensor, ie_device, precision, ir_version): self._test(*self.create_model(minimum, as_tensor), ie_device, precision, ir_version, use_convert_model=True, trace_model=True) @@ -106,6 +108,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_clamp(self, maximum, as_tensor, ie_device, precision, ir_version): self._test(*self.create_model(maximum, as_tensor), ie_device, precision, ir_version, use_convert_model=True, trace_model=True) diff --git a/tests/layer_tests/pytorch_tests/test_constant_pad_nd.py b/tests/layer_tests/pytorch_tests/test_constant_pad_nd.py new file mode 100644 index 00000000000000..7a92983bb1819d --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_constant_pad_nd.py @@ -0,0 +1,37 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestConstantPadND(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(2, 5, 3, 4).astype(np.float32),) + + def create_model(self, pad, value): + + import torch + class aten_constant_pad_nd(torch.nn.Module): + def __init__(self, pad=None, value=None): + super(aten_constant_pad_nd, self).__init__() + self.pad = pad + self.value = value + + def forward(self, x): + return torch.constant_pad_nd(x, self.pad, self.value); + + + ref_net = None + + return aten_constant_pad_nd(pad, value), ref_net, "aten::constant_pad_nd" + + @pytest.mark.parametrize(("pad", "value"), + [((1,1,1,1), 0),((0,2,0,2), -1.0),((3,1,5,2), 0.5),((0,0,0,0), 0),]) + + @pytest.mark.precommit_fx_backend + def test_constant_pad_nd(self, pad, value, ie_device, precision, ir_version): + self._test(*self.create_model(pad, value), + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_copy.py b/tests/layer_tests/pytorch_tests/test_copy.py index 6b4969c0277e8c..d714ef0057b249 100644 --- a/tests/layer_tests/pytorch_tests/test_copy.py +++ b/tests/layer_tests/pytorch_tests/test_copy.py @@ -28,6 +28,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("value", [1, [2.5], range(224)]) def test_copy_(self, value, ie_device, precision, ir_version): self._test(*self.create_model(value), ie_device, precision, ir_version) @@ -63,4 +64,4 @@ def forward_out(self, x, y): @pytest.mark.precommit @pytest.mark.parametrize("out", [True, False]) def test_copy_(self, out, ie_device, precision, ir_version): - self._test(*self.create_model(out), ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out}) \ No newline at end of file + self._test(*self.create_model(out), ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out}) 
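Note on the precommit_fx_backend marker added throughout this patch: it selects these cases for the TorchFX (torch.compile) runs in addition to the existing paths. A minimal sketch of the path those markers exercise is shown below; it assumes an OpenVINO install that registers the "openvino" torch.compile backend, and the module, shapes, and comparison are illustrative only, not part of the test suite.

    import torch

    class PadAndAny(torch.nn.Module):
        def forward(self, x):
            # aten::constant_pad_nd and aten::any, both covered by the new tests above
            padded = torch.constant_pad_nd(x, (1, 1, 1, 1), 0.0)
            return torch.any(padded > 0.5, dim=-1, keepdim=True)

    model = PadAndAny()
    example = torch.randn(2, 5, 3, 4)
    # "openvino" backend name is an assumption about the installed OpenVINO package
    compiled = torch.compile(model, backend="openvino")
    torch.testing.assert_close(compiled(example), model(example))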
diff --git a/tests/layer_tests/pytorch_tests/test_fake_quantize.py b/tests/layer_tests/pytorch_tests/test_fake_quantize.py index ba8cb3eb76e1d2..62b8bc898cd839 100644 --- a/tests/layer_tests/pytorch_tests/test_fake_quantize.py +++ b/tests/layer_tests/pytorch_tests/test_fake_quantize.py @@ -91,7 +91,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit - #@pytest.mark.precommit_fx_backend + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize( "scale, zero_point, quant_min, quant_max", [ @@ -146,6 +146,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize( "scale, zero_point, axis, quant_min, quant_max", [ diff --git a/tests/layer_tests/pytorch_tests/test_flip.py b/tests/layer_tests/pytorch_tests/test_flip.py index 4cc5f2a57cb701..56943e8494c53a 100644 --- a/tests/layer_tests/pytorch_tests/test_flip.py +++ b/tests/layer_tests/pytorch_tests/test_flip.py @@ -37,7 +37,6 @@ def forward_out(self, x, y): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export - @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("axis", [[0], [1], [-1], [1, 2], [2, 3], [1, 2, 3]]) @pytest.mark.parametrize("out", [skip_if_export(True), False]) @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "uint8"]) diff --git a/tests/layer_tests/pytorch_tests/test_pooling.py b/tests/layer_tests/pytorch_tests/test_pooling.py index 51a21af520f38d..d4c47c6537609f 100644 --- a/tests/layer_tests/pytorch_tests/test_pooling.py +++ b/tests/layer_tests/pytorch_tests/test_pooling.py @@ -157,6 +157,7 @@ def test_avg_pool1d(self, params, ceil_mode, count_include_pad, ie_device, preci @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', reason='Ticket - 122715') def test_avg_pool2d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version): @@ -169,6 +170,7 @@ def test_avg_pool2d(self, params, ceil_mode, count_include_pad, ie_device, preci @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', reason='Ticket - 122715') def test_avg_pool3d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version): diff --git a/tests/layer_tests/pytorch_tests/test_select_scatter.py b/tests/layer_tests/pytorch_tests/test_select_scatter.py new file mode 100644 index 00000000000000..02309739ec2dba --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_select_scatter.py @@ -0,0 +1,39 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + + +class TestSelectScatter(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(2, 5, 3, 4).astype(np.float32),) + + def create_model(self, src, dim, index): + + import torch + class aten_select_scatter(torch.nn.Module): + def __init__(self, src=None, dim=None, index=None): + super(aten_select_scatter, self).__init__() + self.src = src + self.dim = dim + self.index = index + + def forward(self, x): + return torch.select_scatter(x, self.src, self.dim, self.index); + + + ref_net = None + + return aten_select_scatter(src, dim, index), ref_net, "aten::select_scatter" + + import torch 
+ @pytest.mark.precommit_fx_backend + @pytest.mark.parametrize(("src", "dim", "index"), + [(torch.ones(2), 0, 0),]) + def aten_select_scatter(self, src, dim, index, ie_device, precision, ir_version): + self._test(*self.create_model(src, dim, index), + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_slice_scatter.py b/tests/layer_tests/pytorch_tests/test_slice_scatter.py new file mode 100644 index 00000000000000..0d291f6bb4d3aa --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_slice_scatter.py @@ -0,0 +1,41 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + + +class TestSliceScatter(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(2, 5, 3, 4).astype(np.float32),) + + def create_model(self, src, dim, start, end, step): + + import torch + class aten_slice_scatter(torch.nn.Module): + def __init__(self, src=None, dim=None, start=None, end=None, step=None): + super(aten_slice_scatter, self).__init__() + self.src = src + self.dim = dim + self.start = start + self.end = end + self.step = step + + def forward(self, x): + return torch.slice_scatter(x, src=self.src, dim=self.dim, start=self.start, end=self.end, step=self.step); + + + ref_net = None + + return aten_slice_scatter(src, dim, start, end, step), ref_net, "aten::slice_scatter" + + import torch + @pytest.mark.precommit_fx_backend + @pytest.mark.parametrize(("src", "dim", "start", "end", "step"), + [(torch.ones(2), 1, 1, 2, 1),]) + def aten_slice_scatter(self, src, dim, start, end, step, ie_device, precision, ir_version): + self._test(*self.create_model(src, dim, start, end, step), + ie_device, precision, ir_version) From e445f702d354b197dc1fbff38e9d578ed89e3aa6 Mon Sep 17 00:00:00 2001 From: Mustafa Cavus Date: Wed, 6 Mar 2024 11:42:25 -0800 Subject: [PATCH 14/25] Code style fix src/frontends/pytorch/src/op/cat.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/frontends/pytorch/src/op/cat.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/frontends/pytorch/src/op/cat.cpp b/src/frontends/pytorch/src/op/cat.cpp index a200ad3f997acf..f8eb62b7babd9e 100644 --- a/src/frontends/pytorch/src/op/cat.cpp +++ b/src/frontends/pytorch/src/op/cat.cpp @@ -116,8 +116,8 @@ OutputVector translate_stack_fx(const NodeContext& context) { // axis can be not present and that means that last input will have List type axis = context.const_input(num_elements - 1); } else { - auto stack_input = - context.mark_node(std::make_shared(context.get_input(static_cast(num_elements - 1)), dim)); + auto stack_input = context.mark_node( + std::make_shared(context.get_input(static_cast(num_elements - 1)), dim)); list_elems.push_back(stack_input); } return translate_cat_common(context, list_elems, axis, true); From d7d26e92c1cd321f5c3fd22562d8746016e6d796 Mon Sep 17 00:00:00 2001 From: Mustafa Cavus Date: Wed, 6 Mar 2024 11:42:43 -0800 Subject: [PATCH 15/25] Code style fix src/frontends/pytorch/src/op/any.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/frontends/pytorch/src/op/any.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/frontends/pytorch/src/op/any.cpp b/src/frontends/pytorch/src/op/any.cpp index 32f4664e8b3195..2db41924c08943 100644 --- a/src/frontends/pytorch/src/op/any.cpp +++ b/src/frontends/pytorch/src/op/any.cpp @@ -3,8 
+3,6 @@ // #include "openvino/frontend/pytorch/node_context.hpp" -#include "openvino/op/reduce_logical_or.hpp" -#include "openvino/op/reshape.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/not_equal.hpp" #include "utils.hpp" From 936cd11ddbc1a675de8836a75aff504864affcc7 Mon Sep 17 00:00:00 2001 From: Mustafa Cavus Date: Wed, 6 Mar 2024 11:43:00 -0800 Subject: [PATCH 16/25] Code style fix src/frontends/pytorch/src/op/any.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/frontends/pytorch/src/op/any.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/frontends/pytorch/src/op/any.cpp b/src/frontends/pytorch/src/op/any.cpp index 2db41924c08943..569c0df23fd6cb 100644 --- a/src/frontends/pytorch/src/op/any.cpp +++ b/src/frontends/pytorch/src/op/any.cpp @@ -5,6 +5,8 @@ #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/not_equal.hpp" +#include "openvino/op/reduce_logical_or.hpp" +#include "openvino/op/reshape.hpp" #include "utils.hpp" namespace ov { From 0f16d8f1fde5c5dac92dd00cd829fc2bc8680822 Mon Sep 17 00:00:00 2001 From: Mustafa Cavus Date: Wed, 6 Mar 2024 11:43:16 -0800 Subject: [PATCH 17/25] Code style fix src/frontends/pytorch/src/op/isinf.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/frontends/pytorch/src/op/isinf.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/frontends/pytorch/src/op/isinf.cpp b/src/frontends/pytorch/src/op/isinf.cpp index 1a100391931f61..5b298316b690c2 100644 --- a/src/frontends/pytorch/src/op/isinf.cpp +++ b/src/frontends/pytorch/src/op/isinf.cpp @@ -2,8 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "openvino/op/add.hpp" - #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/is_inf.hpp" #include "utils.hpp" From 34c37b557d85301d2909e70196021a8611f6df46 Mon Sep 17 00:00:00 2001 From: Mustafa Cavus Date: Wed, 6 Mar 2024 11:43:35 -0800 Subject: [PATCH 18/25] Code style fix src/frontends/pytorch/src/op/isinf.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/frontends/pytorch/src/op/isinf.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/frontends/pytorch/src/op/isinf.cpp b/src/frontends/pytorch/src/op/isinf.cpp index 5b298316b690c2..6fbb2880695dd0 100644 --- a/src/frontends/pytorch/src/op/isinf.cpp +++ b/src/frontends/pytorch/src/op/isinf.cpp @@ -3,6 +3,7 @@ // #include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/add.hpp" #include "openvino/op/is_inf.hpp" #include "utils.hpp" From ff97c151af610f27503838f9434ef581cfda1425 Mon Sep 17 00:00:00 2001 From: Mustafa Cavus Date: Wed, 6 Mar 2024 11:43:47 -0800 Subject: [PATCH 19/25] Code style fix src/frontends/pytorch/src/op/isnan.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/frontends/pytorch/src/op/isnan.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/frontends/pytorch/src/op/isnan.cpp b/src/frontends/pytorch/src/op/isnan.cpp index d995546b22b37e..99c2152a5c4af4 100644 --- a/src/frontends/pytorch/src/op/isnan.cpp +++ b/src/frontends/pytorch/src/op/isnan.cpp @@ -2,8 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "openvino/op/add.hpp" - #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/is_nan.hpp" #include "utils.hpp" From ccd47974265c8989abb73ab1aa3904f3256229e3 Mon Sep 17 00:00:00 2001 From: 
Mustafa Cavus Date: Wed, 6 Mar 2024 11:44:11 -0800 Subject: [PATCH 20/25] Code style fix src/frontends/pytorch/src/op/isnan.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/frontends/pytorch/src/op/isnan.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/frontends/pytorch/src/op/isnan.cpp b/src/frontends/pytorch/src/op/isnan.cpp index 99c2152a5c4af4..d9df6124616268 100644 --- a/src/frontends/pytorch/src/op/isnan.cpp +++ b/src/frontends/pytorch/src/op/isnan.cpp @@ -3,6 +3,7 @@ // #include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/add.hpp" #include "openvino/op/is_nan.hpp" #include "utils.hpp" From f5e8283683d107c57e12faf59bdff4f344cdaedd Mon Sep 17 00:00:00 2001 From: Mustafa Cavus Date: Wed, 6 Mar 2024 11:44:32 -0800 Subject: [PATCH 21/25] Code style fix src/frontends/pytorch/src/op/topk.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/frontends/pytorch/src/op/topk.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/frontends/pytorch/src/op/topk.cpp b/src/frontends/pytorch/src/op/topk.cpp index 345cbabb86299d..50c870c8d4a5d4 100644 --- a/src/frontends/pytorch/src/op/topk.cpp +++ b/src/frontends/pytorch/src/op/topk.cpp @@ -5,7 +5,6 @@ #include "openvino/op/topk.hpp" #include "openvino/frontend/pytorch/node_context.hpp" -#include "openvino/op/util/framework_node.hpp" #include "openvino/op/convert.hpp" #include "utils.hpp" From 238a3f0e9b862600f9c6791cf9c6c5351009c593 Mon Sep 17 00:00:00 2001 From: ynimmaga Date: Wed, 6 Mar 2024 15:03:45 -0800 Subject: [PATCH 22/25] Added embedding_bag and fixed unbind int --- .../pytorch/torchdynamo/op_support.py | 1 + .../pytorch/src/op/embedding_bag.cpp | 20 +++++++++++++++---- src/frontends/pytorch/src/op/split.cpp | 14 +++++++++---- src/frontends/pytorch/src/op_table.cpp | 2 ++ 4 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py index 0b4e69624c4aaa..d2fda2a67267c7 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py @@ -33,6 +33,7 @@ def __init__(self, options): "torch.ops.aten._adaptive_avg_pool2d.default": None, "torch.ops.aten._adaptive_avg_pool3d.default": None, "torch.ops.aten._convolution.default": None, + "torch.ops.aten._embedding_bag.default": None, "torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default": None, "torch.ops.aten._local_scalar_dense.default": None, "torch.ops.aten._log_softmax.default": None, diff --git a/src/frontends/pytorch/src/op/embedding_bag.cpp b/src/frontends/pytorch/src/op/embedding_bag.cpp index 4560ea2a09db4f..999f597f196324 100644 --- a/src/frontends/pytorch/src/op/embedding_bag.cpp +++ b/src/frontends/pytorch/src/op/embedding_bag.cpp @@ -15,10 +15,9 @@ namespace frontend { namespace pytorch { namespace op { -OutputVector translate_embedding_bag(const NodeContext& context) { +OutputVector translate_embedding_bag_common(const NodeContext& context) { // aten::embedding_bag(weight, input, offsets=None, scale_grad_by_freq=False, mode_enum=1, sparse=False, // per_sample_weights=None, include_last_offset=False, padding_idx=None) - num_inputs_check(context, 9, 9); // we have only EmbeddingBagSum case support, check it before translation auto mode = context.const_input(4); 
PYTORCH_OP_CONVERSION_CHECK(mode == 0, "Only sum mode supported for aten::embedding_bag translation"); @@ -43,7 +42,9 @@ OutputVector translate_embedding_bag(const NodeContext& context) { // with offsets case auto offsets = context.get_input(2); offsets = context.mark_node(std::make_shared(offsets, element::i32)); - auto include_last_offset = context.const_input(7); + bool include_last_offset = false; + if (!context.input_is_none(7)) + include_last_offset = context.const_input(7); PYTORCH_OP_CONVERSION_CHECK(!include_last_offset, "Inclusion last offset is not supported"); // no per_sample_wights if (context.input_is_none(6)) { @@ -63,7 +64,18 @@ OutputVector translate_embedding_bag(const NodeContext& context) { return {result, zero, zero, zero}; }; +OutputVector translate_embedding_bag(const NodeContext& context) { + num_inputs_check(context, 9, 9); + return translate_embedding_bag_common(context); +} + +OutputVector translate_embedding_bag_fx(const NodeContext& context) { + num_inputs_check(context, 7, 9); + ov::OutputVector output = translate_embedding_bag_common(context); + return {context.mark_node(make_list_construct(output))}; +} + } // namespace op } // namespace pytorch } // namespace frontend -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/frontends/pytorch/src/op/split.cpp b/src/frontends/pytorch/src/op/split.cpp index b8345a0b4a9700..45689ccb695e59 100644 --- a/src/frontends/pytorch/src/op/split.cpp +++ b/src/frontends/pytorch/src/op/split.cpp @@ -37,12 +37,18 @@ OutputVector translate_chunk_fx(const NodeContext& context) { } OutputVector translate_unbind_int_fx(const NodeContext& context) { - num_inputs_check(context, 2, 3); + num_inputs_check(context, 1, 3); auto input = context.get_input(0); - auto dim = context.get_input(1); - auto dim_val = context.const_input(1); + Output dim; + int64_t dim_val = 0; + if (context.input_is_none(1)) { + dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); + } else { + dim = context.get_input(1); + dim_val = context.const_input(1); + } + auto shape = input.get_shape(); - if (dim_val < 0) { dim_val = static_cast(shape.size()) + dim_val; } diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 3f91f55ae42272..16e879ead9cfe4 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -253,6 +253,7 @@ OP_CONVERTER(translate_constant_pad_nd_fx); OP_CONVERTER(translate_cumsum_fx); OP_CONVERTER(translate_chunk_fx); OP_CONVERTER(translate_div_fx); +OP_CONVERTER(translate_embedding_bag_fx); OP_CONVERTER(translate_expand_fx); OP_CONVERTER(translate_fake_quantize_per_channel_affine_fx); OP_CONVERTER(translate_fake_quantize_per_tensor_affine_fx); @@ -691,6 +692,7 @@ const std::map get_supported_ops_fx() { {"aten._adaptive_avg_pool2d.default", op::translate_adaptive_avg_pool2d}, {"aten._adaptive_avg_pool3d.default", op::translate_adaptive_avg_pool3d}, {"aten._convolution.default", op::translate_convolution}, + {"aten._embedding_bag.default", op::translate_embedding_bag_fx}, {"aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default", op::translate_fake_quantize_per_tensor_affine_fx}, {"aten._local_scalar_dense.default", op::skip_node}, From 255a78b4e0d1dd1c718615845a2c06a9b814f285 Mon Sep 17 00:00:00 2001 From: Mustafa Cavus Date: Wed, 6 Mar 2024 15:19:44 -0800 Subject: [PATCH 23/25] Code style fix src/frontends/pytorch/src/op/split.cpp Co-authored-by: github-actions[bot] 
<41898282+github-actions[bot]@users.noreply.github.com> --- src/frontends/pytorch/src/op/split.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/frontends/pytorch/src/op/split.cpp b/src/frontends/pytorch/src/op/split.cpp index 45689ccb695e59..f43d8258859dd4 100644 --- a/src/frontends/pytorch/src/op/split.cpp +++ b/src/frontends/pytorch/src/op/split.cpp @@ -47,7 +47,6 @@ OutputVector translate_unbind_int_fx(const NodeContext& context) { dim = context.get_input(1); dim_val = context.const_input(1); } - auto shape = input.get_shape(); if (dim_val < 0) { dim_val = static_cast(shape.size()) + dim_val; From e65934edcc72b8b5234465f58fec44b1b9bab036 Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Wed, 6 Mar 2024 18:04:25 -0800 Subject: [PATCH 24/25] TorchFX: Unit test enabled for embedding_bag --- tests/layer_tests/pytorch_tests/test_embedding_bag.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/layer_tests/pytorch_tests/test_embedding_bag.py b/tests/layer_tests/pytorch_tests/test_embedding_bag.py index d0c6d0c532856f..907a6a5609b3ed 100644 --- a/tests/layer_tests/pytorch_tests/test_embedding_bag.py +++ b/tests/layer_tests/pytorch_tests/test_embedding_bag.py @@ -42,6 +42,7 @@ def forward_offsets_per_sample_weights(self, indicies, weight, offsets, per_samp @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("indicies_dtype", ["int", "int32"]) @pytest.mark.parametrize("per_sample_weights", [True, False]) @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', @@ -86,6 +87,7 @@ def forward_per_sample_weights(self, indicies, weight, per_sample_wights): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("indicies_size", [[1, 1], [2, 5], [3, 10], [4, 7]]) @pytest.mark.parametrize("indicies_dtype", ["int", "int32"]) @pytest.mark.parametrize("per_sample_weights", [True, False]) From f9d478c3ccbcfffdfad22735b7aacc462056284d Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Wed, 6 Mar 2024 18:07:18 -0800 Subject: [PATCH 25/25] TorchFX: bitwise_right_shift is temporarily removed --- .../src/openvino/frontend/pytorch/torchdynamo/op_support.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py index d2fda2a67267c7..a64ae17a329a30 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py @@ -78,7 +78,6 @@ def __init__(self, options): "torch.ops.aten.bitwise_and.Tensor": None, "torch.ops.aten.bitwise_not.default": None, "torch.ops.aten.bitwise_or.Tensor": None, - "torch.ops.aten.bitwise_right_shift.Tensor": None, "torch.ops.aten.bitwise_xor.Tensor": None, "torch.ops.aten.bmm.default": None, "torch.ops.aten.cat.default": None,
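Dropping torch.ops.aten.bitwise_right_shift.Tensor from support_dict means the OpenVINO partitioner no longer claims that node; a graph containing it is still expected to compile, with the unclaimed op left to the default fallback path. The sketch below illustrates that expectation; the fallback behaviour and the "openvino" backend name are assumptions about the surrounding torch.compile machinery rather than something this patch asserts.

    import torch

    class ShiftAndAdd(torch.nn.Module):
        def forward(self, x, n):
            shifted = torch.bitwise_right_shift(x, n)  # no longer in support_dict
            return shifted + 1                         # aten::add.Tensor remains supported

    m = ShiftAndAdd()
    x = torch.arange(16, dtype=torch.int32)
    n = torch.full_like(x, 2)
    compiled = torch.compile(m, backend="openvino")  # assumed backend registration
    torch.testing.assert_close(compiled(x, n), m(x, n))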