From 276153d1cee057a230f64de3ea95220be52ef333 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Fri, 17 Nov 2023 20:13:15 +0100
Subject: [PATCH] [PT FE] Support aten::_weight_norm and aten::full with
 scalar size (#21160)

* Support aten::_weight_norm and aten::full with scalar size

* Add op_table changes

* Add comments
---
 src/frontends/pytorch/src/op/full.cpp              |  3 ++
 src/frontends/pytorch/src/op/norm.cpp              | 29 +++++++++++++++++++
 src/frontends/pytorch/src/op_table.cpp             |  2 ++
 src/frontends/pytorch/src/utils.cpp                |  7 +++++
 src/frontends/pytorch/src/utils.hpp                |  5 ++++
 .../pytorch_tests/pytorch_layer_test_class.py      |  3 +-
 tests/layer_tests/pytorch_tests/test_norm.py       | 16 ++++++++++
 7 files changed, 63 insertions(+), 2 deletions(-)

diff --git a/src/frontends/pytorch/src/op/full.cpp b/src/frontends/pytorch/src/op/full.cpp
index e8bfa1c7ce99d7..f7b3fdc44e6bc8 100644
--- a/src/frontends/pytorch/src/op/full.cpp
+++ b/src/frontends/pytorch/src/op/full.cpp
@@ -27,6 +27,9 @@ using namespace ov::op;
 
 namespace {
 Output<Node> base_translate_full(const NodeContext& context, const Output<Node>& sizes, const Output<Node>& value) {
+    if (is_empty_list(sizes)) {
+        return value;
+    }
     return context.mark_node(std::make_shared<v3::Broadcast>(value, sizes));
 }
 
diff --git a/src/frontends/pytorch/src/op/norm.cpp b/src/frontends/pytorch/src/op/norm.cpp
index d3136b7e76ad48..3ba697fdef1bfa 100644
--- a/src/frontends/pytorch/src/op/norm.cpp
+++ b/src/frontends/pytorch/src/op/norm.cpp
@@ -4,9 +4,12 @@
 
 #include "openvino/frontend/pytorch/node_context.hpp"
 #include "openvino/op/abs.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/concat.hpp"
 #include "openvino/op/constant.hpp"
 #include "openvino/op/convert.hpp"
 #include "openvino/op/convert_like.hpp"
+#include "openvino/op/divide.hpp"
 #include "openvino/op/gather.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/not_equal.hpp"
@@ -151,6 +154,32 @@ OutputVector translate_norm(const NodeContext& context) {
     return {res};
 };
 
+OutputVector translate_weight_norm(const NodeContext& context) {
+    // aten::_weight_norm(Tensor v, Tensor g, int dim=0) -> Tensor
+    num_inputs_check(context, 3, 3);
+    auto x = context.get_input(0);
+    auto y = context.get_input(1);
+    Output<Node> dim;
+    auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+    auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+    auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(x, element::i32));
+    auto rank = context.mark_node(std::make_shared<v3::ShapeOf>(input_shape, element::i32));
+    rank = context.mark_node(std::make_shared<v0::Squeeze>(rank, zero));
+    if (context.input_is_none(2)) {
+        dim = context.mark_node(std::make_shared<v0::Range>(zero, rank, one));
+    } else {
+        dim = context.get_input(2);
+        auto dims_before = context.mark_node(std::make_shared<v0::Range>(zero, dim, one));
+        auto dim_next = context.mark_node(std::make_shared<v1::Add>(dim, one));
+        auto dims_after = context.mark_node(std::make_shared<v0::Range>(dim_next, rank, one));
+        dim = context.mark_node(std::make_shared<v0::Concat>(OutputVector{dims_before, dims_after}, 0));
+    }
+    Output<Node> res;
+    auto norm = context.mark_node(std::make_shared<v4::ReduceL2>(x, dim, true));
+    auto y_norm = context.mark_node(std::make_shared<v1::Divide>(y, norm));
+    return {context.mark_node(std::make_shared<v1::Multiply>(x, y_norm))};
+};
+
 OutputVector translate_linalg_vector_norm(const NodeContext& context) {
     // aten::linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType?
// dtype=None) -> Tensor diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 452a8927627629..342778078bbd48 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -203,6 +203,7 @@ OP_CONVERTER(translate_upsample_nearest3d); OP_CONVERTER(translate_upsample_trilinear3d); OP_CONVERTER(translate_var); OP_CONVERTER(translate_var_mean); +OP_CONVERTER(translate_weight_norm); OP_CONVERTER(translate_where); OP_CONVERTER(translate_zeros); OP_CONVERTER(translate_zeros_like); @@ -244,6 +245,7 @@ const std::map get_supported_ops_ts() { {"aten::_native_multi_head_attention", op::translate_native_multi_head_attention}, {"aten::_set_item", op::translate_set_item}, {"aten::_shape_as_tensor", op::translate_shape_as_tensor}, + {"aten::_weight_norm", op::translate_weight_norm}, {"aten::abs", op::translate_1to1_match_1_inputs}, {"aten::acos", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten::acos_", op::inplace_op>}, diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index 7decae35b30bbb..15b8c1cd6e07a7 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -191,6 +191,13 @@ Output concat_list_construct(const Output& input) { return input; } +bool is_empty_list(const Output& input) { + if (const auto list_construct = cast_fw_node(input.get_node_shared_ptr(), "prim::ListConstruct")) { + return list_construct->get_input_size() == 0; + } + return false; +} + namespace { std::shared_ptr create_fw_node_with_exception(const NodeContext& context, const ov::OutputVector& inputs, diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp index 20bae6fa62f5c3..41e19bf03f92b2 100644 --- a/src/frontends/pytorch/src/utils.hpp +++ b/src/frontends/pytorch/src/utils.hpp @@ -52,6 +52,11 @@ op::PadType convert_pad(const std::string& pt_pad); Output concat_list_construct(const Output& input); +/// \brief Checks if input represents empty list. +/// \param input Input to check. +/// \return true if input is empty list, false - if input is non-empty or non-list. 
+bool is_empty_list(const Output<Node>& input);
+
 OutputVector make_framework_node_ignore_bodies(const NodeContext& context, const std::string& exception);
 OutputVector make_framework_node(const NodeContext& context, const std::string& exception);
 
diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
index f8b726c4c5f66d..adee1fb3ccc4c9 100644
--- a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
+++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
@@ -137,8 +137,7 @@ def use_torch_compile_backend():
             assert 'quant_size' in kwargs, "quant size must be specified for quantized_ops flag"
             quant_size = kwargs['quant_size']
         for i in range(len(infer_res)):
-            cur_fw_res = flatten_fw_res[i].contiguous().numpy(
-            ) if isinstance(flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
+            cur_fw_res = flatten_fw_res[i].contiguous().numpy(force=True) if isinstance(flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
             if np.array(cur_fw_res).size == 0:
                 continue
             cur_ov_res = infer_res[compiled.output(i)]
diff --git a/tests/layer_tests/pytorch_tests/test_norm.py b/tests/layer_tests/pytorch_tests/test_norm.py
index bbbdb3bae34293..c884e6c17ff042 100644
--- a/tests/layer_tests/pytorch_tests/test_norm.py
+++ b/tests/layer_tests/pytorch_tests/test_norm.py
@@ -72,6 +72,22 @@ def test_norm_tensor(self, ie_device, precision, ir_version, p, dim, keepdim):
         self._test(*self.create_model_tensor_norm(p, dim, keepdim), ie_device, precision, ir_version)
 
 
+class TestWeightNorm(PytorchLayerTest):
+
+    def _prepare_input(self):
+        return (np.random.randn(1, 60, 20).astype(np.float32),)
+
+    def create_model(self):
+        from torch import nn
+        from torch.nn.utils import weight_norm
+        return weight_norm(nn.Linear(20, 40), name='weight'), None, "aten::_weight_norm"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_weight_norm(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, freeze_model=False)
+
+
 class TestFrobeniusNorm(PytorchLayerTest):
 
     def _prepare_input(self, out=False, dtype="float32"):
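
The change above can be exercised end to end from Python. The following is a minimal sketch (not part of the patch) of the user-level scenario it enables: tracing a weight-normalized module, whose graph contains aten::_weight_norm, and converting it with OpenVINO. The openvino import, the tensor shapes, and the "CPU" device string are illustrative assumptions, not taken from the patch.

    import numpy as np
    import torch
    from torch import nn
    from torch.nn.utils import weight_norm

    import openvino as ov

    # weight_norm reparametrizes the weight as g * v / ||v||; the traced graph
    # expresses this through aten::_weight_norm, which the new converter handles.
    model = weight_norm(nn.Linear(20, 40), name="weight")
    model.eval()

    example = torch.randn(1, 60, 20)
    ov_model = ov.convert_model(model, example_input=example)
    compiled = ov.compile_model(ov_model, "CPU")

    data = np.random.randn(1, 60, 20).astype(np.float32)
    expected = model(torch.from_numpy(data)).detach().numpy()
    actual = compiled(data)[0]
    print(np.allclose(expected, actual, atol=1e-5))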