[PT FE] Support aten::_weight_norm and aten::full with scalar size (openvinotoolkit#21160)

* Support aten::_weight_norm and aten::full with scalar size

* Add op_table changes

* Add comments
mvafin authored Nov 17, 2023
1 parent 44d56b9 commit 276153d
Showing 7 changed files with 63 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/frontends/pytorch/src/op/full.cpp
@@ -27,6 +27,9 @@ using namespace ov::op;

namespace {
Output<Node> base_translate_full(const NodeContext& context, const Output<Node>& sizes, const Output<Node>& value) {
    if (is_empty_list(sizes)) {
        return value;
    }
    return context.mark_node(std::make_shared<v3::Broadcast>(value, sizes));
}

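The `full.cpp` change covers PyTorch calls like `torch.full([], value)`, where the size argument is an empty list and the result is a 0-d (scalar) tensor. A minimal sketch of the pattern this change targets (hypothetical standalone script, not part of this PR):

```python
import torch

def fill_scalar(x):
    # torch.full with an empty size list yields a 0-d (scalar) tensor,
    # so no Broadcast is needed and the fill value can be returned as-is.
    return x + torch.full([], 3.0)

# In the captured graph the sizes input of aten::full typically appears as a
# prim::ListConstruct with no inputs, which is what is_empty_list() detects.
print(torch.jit.trace(fill_scalar, torch.randn(2, 2)).graph)
```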
29 changes: 29 additions & 0 deletions src/frontends/pytorch/src/op/norm.cpp
@@ -4,9 +4,12 @@

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/abs.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/not_equal.hpp"
@@ -151,6 +154,32 @@ OutputVector translate_norm(const NodeContext& context) {
    return {res};
};

OutputVector translate_weight_norm(const NodeContext& context) {
    // aten::_weight_norm(Tensor v, Tensor g, int dim=0) -> Tensor
    // Computes v * (g / ||v||), where the L2 norm is reduced over every axis except `dim`.
    num_inputs_check(context, 3, 3);
    auto x = context.get_input(0);
    auto y = context.get_input(1);
    Output<Node> dim;
    auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
    auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
    auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(x, element::i32));
    auto rank = context.mark_node(std::make_shared<v3::ShapeOf>(input_shape, element::i32));
    rank = context.mark_node(std::make_shared<v0::Squeeze>(rank, zero));
    if (context.input_is_none(2)) {
        // No dim given: reduce over all axes.
        dim = context.mark_node(std::make_shared<v0::Range>(zero, rank, one));
    } else {
        // Reduce over every axis except `dim`: [0, dim) ++ [dim + 1, rank).
        dim = context.get_input(2);
        auto dims_before = context.mark_node(std::make_shared<v0::Range>(zero, dim, one));
        auto dim_next = context.mark_node(std::make_shared<v1::Add>(dim, one));
        auto dims_after = context.mark_node(std::make_shared<v0::Range>(dim_next, rank, one));
        dim = context.mark_node(std::make_shared<v0::Concat>(OutputVector{dims_before, dims_after}, 0));
    }
    auto norm = context.mark_node(std::make_shared<v4::ReduceL2>(x, dim, true));
    auto y_norm = context.mark_node(std::make_shared<v1::Divide>(y, norm));
    return {context.mark_node(std::make_shared<v1::Multiply>(x, y_norm))};
};

OutputVector translate_linalg_vector_norm(const NodeContext& context) {
// aten::linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType?
// dtype=None) -> Tensor
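The new converter decomposes `aten::_weight_norm` into `v * (g / ||v||)`, with the L2 norm reduced over every axis except `dim` (or over all axes when `dim` is absent). A sketch of the same computation in PyTorch for sanity-checking the decomposition (illustrative, not part of the PR):

```python
import torch

v = torch.randn(40, 20)  # direction tensor (weight_v)
g = torch.randn(40, 1)   # magnitude tensor (weight_g)
dim = 0

# Reduce over every axis except `dim`, mirroring the Range/Concat logic above.
reduce_axes = tuple(i for i in range(v.dim()) if i != dim)
norm = torch.linalg.vector_norm(v, ord=2, dim=reduce_axes, keepdim=True)
expected = v * (g / norm)

# torch._weight_norm is the underlying ATen op the converter translates.
assert torch.allclose(torch._weight_norm(v, g, dim), expected)
```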
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -203,6 +203,7 @@ OP_CONVERTER(translate_upsample_nearest3d);
OP_CONVERTER(translate_upsample_trilinear3d);
OP_CONVERTER(translate_var);
OP_CONVERTER(translate_var_mean);
OP_CONVERTER(translate_weight_norm);
OP_CONVERTER(translate_where);
OP_CONVERTER(translate_zeros);
OP_CONVERTER(translate_zeros_like);
@@ -244,6 +245,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::_native_multi_head_attention", op::translate_native_multi_head_attention},
{"aten::_set_item", op::translate_set_item},
{"aten::_shape_as_tensor", op::translate_shape_as_tensor},
{"aten::_weight_norm", op::translate_weight_norm},
{"aten::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},
{"aten::acos", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acos>},
{"aten::acos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Acos>>},
7 changes: 7 additions & 0 deletions src/frontends/pytorch/src/utils.cpp
@@ -191,6 +191,13 @@ Output<Node> concat_list_construct(const Output<Node>& input) {
    return input;
}

bool is_empty_list(const Output<Node>& input) {
    if (const auto list_construct = cast_fw_node(input.get_node_shared_ptr(), "prim::ListConstruct")) {
        return list_construct->get_input_size() == 0;
    }
    return false;
}

namespace {
std::shared_ptr<PtFrameworkNode> create_fw_node_with_exception(const NodeContext& context,
const ov::OutputVector& inputs,
5 changes: 5 additions & 0 deletions src/frontends/pytorch/src/utils.hpp
@@ -52,6 +52,11 @@ op::PadType convert_pad(const std::string& pt_pad);

Output<Node> concat_list_construct(const Output<Node>& input);

/// \brief Checks if the input represents an empty list.
/// \param input Input to check.
/// \return true if the input is an empty list, false if it is non-empty or not a list.
bool is_empty_list(const Output<Node>& input);

OutputVector make_framework_node_ignore_bodies(const NodeContext& context, const std::string& exception);
OutputVector make_framework_node(const NodeContext& context, const std::string& exception);

3 changes: 1 addition & 2 deletions tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
@@ -137,8 +137,7 @@ def use_torch_compile_backend():
        assert 'quant_size' in kwargs, "quant size must be specified for quantized_ops flag"
        quant_size = kwargs['quant_size']
        for i in range(len(infer_res)):
            cur_fw_res = flatten_fw_res[i].contiguous().numpy(force=True) if isinstance(flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
            if np.array(cur_fw_res).size == 0:
                continue
            cur_ov_res = infer_res[compiled.output(i)]
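The test-harness tweak replaces the two-line `.numpy()` call with `numpy(force=True)` (available since PyTorch 1.13), which detaches, moves to CPU, and copies as needed, so tensors that require grad or live on another device no longer break result flattening. A small illustration (assuming a recent PyTorch):

```python
import torch

t = torch.arange(6.0, requires_grad=True)
# t.numpy() would raise: "Can't call numpy() on Tensor that requires grad"
a = t.numpy(force=True)  # roughly t.detach().cpu().resolve_conj().resolve_neg().numpy()
print(a)  # [0. 1. 2. 3. 4. 5.]
```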
16 changes: 16 additions & 0 deletions tests/layer_tests/pytorch_tests/test_norm.py
@@ -72,6 +72,22 @@ def test_norm_tensor(self, ie_device, precision, ir_version, p, dim, keepdim):
        self._test(*self.create_model_tensor_norm(p, dim, keepdim),
                   ie_device, precision, ir_version)


class TestWeightNorm(PytorchLayerTest):
    def _prepare_input(self):
        return (np.random.randn(1, 60, 20).astype(np.float32),)

    def create_model(self):
        from torch import nn
        from torch.nn.utils import weight_norm

        return weight_norm(nn.Linear(20, 40), name='weight'), None, "aten::_weight_norm"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_weight_norm(self, ie_device, precision, ir_version):
        self._test(*self.create_model(), ie_device, precision, ir_version,
                   trace_model=True, freeze_model=False)


class TestFrobeniusNorm(PytorchLayerTest):
    def _prepare_input(self, out=False, dtype="float32"):
