From 92649105ed7344f2eaa1c3305d1293eaac0b5040 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Thu, 2 Feb 2023 02:15:33 +0100 Subject: [PATCH] Add eltwise types resolving. Support big int constants. (#15415) * Add eltwise types resolving. Support big int constants. * Update src/bindings/python/src/openvino/frontend/pytorch/decoder.py * Small fix * Fix some cases * Add tests for add in different types * Add tests for mul * Add tests for sub and div * Small fixes * Return list handling (needed for empty lists) * Add test for empty list * Update src/frontends/pytorch/src/op/mul.cpp Co-authored-by: Roman Kazantsev * Use refs instead of ptrs * Apply suggestions from code review Co-authored-by: Roman Kazantsev * Apply code review suggestions * Fix code style * Add more eltwise ops --------- Co-authored-by: Roman Kazantsev --- .../src/openvino/frontend/pytorch/decoder.py | 78 ++-- src/frontends/pytorch/src/op/add.cpp | 13 +- src/frontends/pytorch/src/op/div.cpp | 37 +- src/frontends/pytorch/src/op/pow.cpp | 25 ++ src/frontends/pytorch/src/op/sub.cpp | 13 +- src/frontends/pytorch/src/op_table.cpp | 19 +- src/frontends/pytorch/src/utils.cpp | 101 +++++ src/frontends/pytorch/src/utils.hpp | 19 + .../py_frontend_tests/test_torch_decoder.py | 362 ++++++++++++++++++ tests/layer_tests/pytorch_tests/test_add.py | 66 +++- .../pytorch_tests/test_comparision.py | 79 +++- tests/layer_tests/pytorch_tests/test_div.py | 101 +++-- tests/layer_tests/pytorch_tests/test_mul.py | 95 +++-- tests/layer_tests/pytorch_tests/test_pow.py | 62 +++ tests/layer_tests/pytorch_tests/test_sub.py | 71 +++- 15 files changed, 1009 insertions(+), 132 deletions(-) create mode 100644 src/frontends/pytorch/src/op/pow.cpp diff --git a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py index 09b6287bf3d9a4..90faeb573a0547 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py @@ -16,7 +16,10 @@ def get_type_from_py_type(value): if isinstance(value, float): return OVType.f32 if isinstance(value, int): - return OVType.i32 + # Python int is 64 bit, but we will convert it to int32 except cases when it can't fit in 32 bits + if torch.iinfo(torch.int).min <= value <= torch.iinfo(torch.int).max: + return OVType.i32 + return OVType.i64 if isinstance(value, bool): return OVType.boolean return OVType.dynamic @@ -27,13 +30,13 @@ def ivalue_to_constant(ivalue): if ov_type.is_static(): return op.Constant(ov_type, Shape([]), [ivalue]).outputs() - if isinstance(ivalue, list): + if isinstance(ivalue, (list, tuple)): assert len(ivalue) > 0, "Can't deduce type for empty list" ov_type = get_type_from_py_type(ivalue[0]) assert ov_type.is_static(), "Can't deduce type for list" return op.Constant(ov_type, Shape([len(ivalue)]), ivalue).outputs() - if ivalue.type() in pt_to_ov_type_map: + if isinstance(ivalue, torch.Tensor) and ivalue.type() in pt_to_ov_type_map: try: ovshape = PartialShape(ivalue.size()) ovtype = pt_to_ov_type_map[ivalue.type()] @@ -46,6 +49,7 @@ def ivalue_to_constant(ivalue): ovshape = PartialShape(nvalues.shape) ov_const = op.Constant(ovtype, ovshape.get_shape(), nvalues.flatten().tolist()) return ov_const.outputs() + return None def get_value_from_getattr(getattr_node, self_module): @@ -69,25 +73,22 @@ def get_value_from_getattr(getattr_node, self_module): pt_to_ov_type_map = { "float": OVType.f32, "int": OVType.i32, + "bool": OVType.boolean, + "torch.float16": OVType.f16, 
"torch.float32": OVType.f32, + "torch.float64": OVType.f64, + "torch.uint8": OVType.u8, + "torch.int8": OVType.i8, "torch.int32": OVType.i32, - "torch.bool": OVType.boolean, "torch.int64": OVType.i64, + "torch.bool": OVType.boolean, + "torch.DoubleTensor": OVType.f64, "torch.FloatTensor": OVType.f32, "torch.IntTensor": OVType.i32, "torch.LongTensor": OVType.i64, "torch.BoolTensor": OVType.boolean, } -pt_to_py_type_map = { - "float": "float", - "int": "int", - "torch.float32": "float", - "torch.int32": "int", - "torch.int64": "int", - "torch.bool": "bool", -} - np_to_ov_type_map = { "float32": OVType.f32, "int32": OVType.i32, @@ -106,7 +107,7 @@ def __init__(self, pt_module, graph_element=None): self.graph_element = graph_element self.pt_module = pt_module - def inputs(self): + def inputs(self) -> list: return [x.unique() for x in self.graph_element.inputs()] def get_input(self, index: int): @@ -150,7 +151,7 @@ def _get_known_type_for_value(self, pt_type): # Not yet recognized return OVAny(OVType.dynamic) - def get_shape_for_value(self, value): + def get_shape_for_value(self, value: torch.Value): if value.isCompleteTensor(): ps = PartialShape(value.type().sizes()) return ps @@ -161,7 +162,7 @@ def get_shape_for_value(self, value): pass return PartialShape.dynamic() - def get_type_for_value(self, value): + def get_type_for_value(self, value: torch.Value): full_type = self._get_known_type_for_value(value.type()) return full_type @@ -184,46 +185,46 @@ def get_output_transpose_order(self, index: int) -> list: def get_subgraph_size(self) -> int: return len(self.get_subgraphs()) if hasattr(self.graph_element, "blocks") else 1 - def visit_subgraph(self, node_visitor): + def visit_subgraph(self, node_visitor) -> None: # make sure topological order is satisfied for node in self.graph_element.nodes(): decoder = TorchScriptPythonDecoder(self.pt_module, node) self.m_decoders.append(decoder) node_visitor(decoder) - def get_subgraphs(self): + def get_subgraphs(self) -> list: return list(self.graph_element.blocks()) - def get_subgraph_decoder(self, index): + def get_subgraph_decoder(self, index: int): decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index]) self.m_decoders.append(decoder) return decoder - def get_op_type(self): + def get_op_type(self) -> str: return self.graph_element.kind() - def get_schema(self): + def get_schema(self) -> str: return self.graph_element.schema() - def outputs(self): + def outputs(self) -> list: return [x.unique() for x in self.graph_element.outputs()] - def _raw_outputs(self): + def _raw_outputs(self) -> list: return list(self.graph_element.outputs()) - def _raw_output(self, index): + def _raw_output(self, index: int): return self._raw_outputs()[index] - def _raw_inputs(self): + def _raw_inputs(self) -> list: return list(self.graph_element.inputs()) - def _raw_input(self, index): + def _raw_input(self, index: int): return self._raw_inputs()[index] def num_of_outputs(self): return len(self.outputs()) - def output(self, index): + def output(self, index: int): return self.outputs()[index] def mark_node(self, node): @@ -232,7 +233,7 @@ def mark_node(self, node): def try_decode_get_attr(self): pt_value = get_value_from_getattr(self.graph_element, self.pt_module) assert pt_value is not None, "Couldn't retrieve value from prim::GetAttr" - if not isinstance(pt_value, torch.jit.ScriptModule) or isinstance(pt_value, torch.jit.TracedModule): + if not isinstance(pt_value, (torch.jit.ScriptModule, torch.jit.TracedModule)): return ivalue_to_constant(pt_value) 
else: return [] @@ -244,17 +245,10 @@ def as_constant(self): pt_type = pt_value.type() if isinstance(pt_type, torch.TensorType): - return self.as_constant_tensor(pt_value) + return self._as_constant_tensor(pt_value) if isinstance(pt_type, torch.ListType): - return self.as_constant_list(pt_value) - if str(pt_type) in ["torch.int32", "int"]: - return op.Constant(OVType.i32, Shape([]), [pt_value.toIValue()]).outputs() - if str(pt_type) in ["torch.float", "torch.FloatType", "float"]: - return op.Constant(OVType.f32, Shape([]), [pt_value.toIValue()]).outputs() - if str(pt_type) in ["torch.bool", "bool"]: - return op.Constant(OVType.boolean, Shape([]), [pt_value.toIValue()]).outputs() - - return None + return self._as_constant_list(pt_value) + return ivalue_to_constant(pt_value.toIValue()) def as_string(self): if not self.get_op_type() == "prim::Constant": @@ -265,7 +259,8 @@ def as_string(self): return pt_value.toIValue() return None - def as_constant_tensor(self, pt_value): + @staticmethod + def _as_constant_tensor(pt_value: torch.Value): ivalue = pt_value.toIValue() if pt_value.isCompleteTensor(): try: @@ -295,7 +290,8 @@ def as_constant_tensor(self, pt_value): return ivalue_to_constant(ivalue) return None - def as_constant_list(self, pt_value): + @staticmethod + def _as_constant_list(pt_value: torch.Value): # For now it is treat a list as a 1D tensor; it is required by converters to avoid need to massively # rewrite them in that part where constant attributes are queried pt_element_type = str(pt_value.type().getElementType()) @@ -308,7 +304,7 @@ def as_constant_list(self, pt_value): ov_const = op.Constant(ovtype, ovshape.get_shape(), ivalue) return ov_const.outputs() - def input_is_none(self, index): + def input_is_none(self, index: int) -> bool: if index >= len(self.inputs()) or self._raw_input(index) is None: return True else: diff --git a/src/frontends/pytorch/src/op/add.cpp b/src/frontends/pytorch/src/op/add.cpp index cf58c3e61d6c47..430b88d262c18c 100644 --- a/src/frontends/pytorch/src/op/add.cpp +++ b/src/frontends/pytorch/src/op/add.cpp @@ -2,8 +2,11 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "openvino/op/add.hpp" + #include "openvino/frontend/pytorch/node_context.hpp" -#include "openvino/opsets/opset10.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/multiply.hpp" #include "utils.hpp" namespace ov { @@ -12,12 +15,14 @@ namespace pytorch { namespace op { OutputVector translate_add(NodeContext& context) { + auto lhs = context.get_input(0); auto rhs = context.get_input(1); + align_eltwise_input_types(context, lhs, rhs); if (!context.input_is_none(2)) { - auto converted_alpha = std::make_shared(context.get_input(2), rhs); - rhs = std::make_shared(converted_alpha, rhs); + auto converted_alpha = context.mark_node(std::make_shared(context.get_input(2), rhs)); + rhs = context.mark_node(std::make_shared(converted_alpha, rhs)); } - return {context.mark_node(std::make_shared(context.get_input(0), rhs))}; + return {context.mark_node(std::make_shared(lhs, rhs))}; }; } // namespace op diff --git a/src/frontends/pytorch/src/op/div.cpp b/src/frontends/pytorch/src/op/div.cpp index 0c5b3943511156..8fd170bed12edf 100644 --- a/src/frontends/pytorch/src/op/div.cpp +++ b/src/frontends/pytorch/src/op/div.cpp @@ -3,9 +3,14 @@ // #include "openvino/frontend/pytorch/node_context.hpp" -#include "openvino/opsets/opset10.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/floor.hpp" #include 
"utils.hpp" +using namespace ov::op; + namespace ov { namespace frontend { namespace pytorch { @@ -14,21 +19,27 @@ namespace op { OutputVector translate_div(NodeContext& context) { auto x = context.get_input(0); auto y = context.get_input(1); - auto res = context.mark_node(std::make_shared(x, y, true)); + std::string rounding_mode = ""; if (!context.input_is_none(2)) { - auto rounding_mode = context.const_input(2); - if (rounding_mode == "floor") { - res = context.mark_node(std::make_shared(res)); - } else if (rounding_mode == "trunc") { - const auto convert = context.mark_node(std::make_shared(res, element::i64)); - res = context.mark_node(std::make_shared(convert, x)); - } else { - FRONT_END_OP_CONVERSION_CHECK(false, - "Openvino Pytorch Frontend doesn't support rounding mode ", - rounding_mode, - " for aten::div"); + rounding_mode = context.const_input(2); + } + if (rounding_mode.empty()) { + // if no rounding mode and both inputs are ints cast BOTH to fp32 + const auto x_dtype = x.get_element_type(); + const auto y_dtype = y.get_element_type(); + if (x_dtype.is_static() && x_dtype.is_integral() && y_dtype.is_static() && y_dtype.is_integral()) { + x = context.mark_node(std::make_shared(x, element::f32)); + y = context.mark_node(std::make_shared(y, element::f32)); } } + align_eltwise_input_types(context, x, y, true); + auto res = context.mark_node(std::make_shared(x, y, true)); + if (rounding_mode == "floor") { + res = context.mark_node(std::make_shared(res)); + } else if (rounding_mode == "trunc") { + const auto convert = context.mark_node(std::make_shared(res, element::i64)); + res = context.mark_node(std::make_shared(convert, x)); + } return {res}; }; diff --git a/src/frontends/pytorch/src/op/pow.cpp b/src/frontends/pytorch/src/op/pow.cpp new file mode 100644 index 00000000000000..46eb3489abaf04 --- /dev/null +++ b/src/frontends/pytorch/src/op/pow.cpp @@ -0,0 +1,25 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/power.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_pow(NodeContext& context) { + num_inputs_check(context, 1, 2); + auto lhs = context.get_input(0); + auto rhs = context.get_input(1); + align_eltwise_input_types(context, lhs, rhs, true); + return {context.mark_node(std::make_shared(lhs, rhs))}; +} + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op/sub.cpp b/src/frontends/pytorch/src/op/sub.cpp index a7fa86663bd43c..fad5d19007bef3 100644 --- a/src/frontends/pytorch/src/op/sub.cpp +++ b/src/frontends/pytorch/src/op/sub.cpp @@ -3,9 +3,13 @@ // #include "openvino/frontend/pytorch/node_context.hpp" -#include "openvino/opsets/opset10.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/subtract.hpp" #include "utils.hpp" +using namespace ov::op; + namespace ov { namespace frontend { namespace pytorch { @@ -14,13 +18,14 @@ namespace op { OutputVector translate_sub(NodeContext& context) { auto x = context.get_input(0); auto y = context.get_input(1); + align_eltwise_input_types(context, x, y); // default alpha is 1 so no need to multiply if alpha is not provided if (!context.input_is_none(2)) { auto alpha = context.get_input(2); - auto casted_alpha = context.mark_node(std::make_shared(alpha, y)); - y = context.mark_node(std::make_shared(casted_alpha, 
y)); + auto casted_alpha = context.mark_node(std::make_shared(alpha, y)); + y = context.mark_node(std::make_shared(casted_alpha, y)); } - return {context.mark_node(std::make_shared(x, y))}; + return {context.mark_node(std::make_shared(x, y))}; }; } // namespace op diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index b4b48668bc196c..de9a5a01275895 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -77,6 +77,7 @@ OP_CONVERTER(translate_numel); OP_CONVERTER(translate_ones); OP_CONVERTER(translate_ones_like); OP_CONVERTER(translate_pad); +OP_CONVERTER(translate_pow); OP_CONVERTER(translate_reciprocal); OP_CONVERTER(translate_relu6); OP_CONVERTER(translate_remainder); @@ -175,7 +176,7 @@ const std::map get_supported_ops() { {"aten::dropout_", op::skip_node}, {"aten::elu", op::translate_elu}, {"aten::embedding", op::translate_embedding}, - {"aten::eq", op::translate_1to1_match_2_inputs}, + {"aten::eq", op::translate_1to1_match_2_inputs_align_types}, {"aten::exp", op::translate_1to1_match_1_inputs}, {"aten::expand", op::translate_expand}, {"aten::expand_as", op::translate_expand_as}, @@ -191,8 +192,8 @@ const std::map get_supported_ops() { {"aten::gelu", op::translate_gelu}, {"aten::glu", op::translate_glu}, {"aten::group_norm", op::translate_group_norm}, - {"aten::ge", op::translate_1to1_match_2_inputs}, - {"aten::gt", op::translate_1to1_match_2_inputs}, + {"aten::ge", op::translate_1to1_match_2_inputs_align_types}, + {"aten::gt", op::translate_1to1_match_2_inputs_align_types}, {"aten::grid_sampler", op::translate_grid_sampler}, {"aten::hardsigmoid", op::translate_1to1_match_1_inputs}, {"aten::hardswish", op::translate_1to1_match_1_inputs}, @@ -209,8 +210,8 @@ const std::map get_supported_ops() { {"aten::leaky_relu_", op::inplace_op>}, {"aten::len", op::translate_len}, {"aten::linear", op::translate_linear}, - {"aten::le", op::translate_1to1_match_2_inputs}, - {"aten::lt", op::translate_1to1_match_2_inputs}, + {"aten::le", op::translate_1to1_match_2_inputs_align_types}, + {"aten::lt", op::translate_1to1_match_2_inputs_align_types}, {"aten::log", op::translate_log}, {"aten::log_", op::inplace_op}, {"aten::log2", op::translate_log2}, @@ -228,9 +229,9 @@ const std::map get_supported_ops() { {"aten::mm", op::translate_1to1_match_2_inputs}, {"aten::bmm", op::translate_1to1_match_2_inputs}, {"aten::matmul", op::translate_1to1_match_2_inputs}, - {"aten::mul", op::translate_1to1_match_2_inputs}, - {"aten::mul_", op::inplace_op>}, - {"aten::ne", op::translate_1to1_match_2_inputs}, + {"aten::mul", op::translate_1to1_match_2_inputs_align_types}, + {"aten::mul_", op::inplace_op>}, + {"aten::ne", op::translate_1to1_match_2_inputs_align_types}, {"aten::neg", op::translate_neg}, {"aten::norm", op::translate_norm}, {"aten::nonzero", op::translate_nonzero}, @@ -242,7 +243,7 @@ const std::map get_supported_ops() { {"aten::ones_like", op::translate_ones_like}, {"aten::pad", op::translate_pad}, {"aten::permute", op::translate_1to1_match_2_inputs}, - {"aten::pow", op::translate_1to1_match_2_inputs}, + {"aten::pow", op::translate_pow}, {"aten::reciprocal", op::translate_reciprocal}, {"aten::relu", op::translate_1to1_match_1_inputs}, {"aten::relu_", op::inplace_op>}, diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index e09dfccd2d69c2..2946eca98ef809 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -459,6 +459,107 @@ Any 
simplified_type_interpret(Any type) { return type; } +namespace { +std::unordered_map bit_to_float{ + {16, element::f16}, + {32, element::f32}, + {64, element::f64}, +}; +std::unordered_map bit_to_int{ + // {4, element::i4}, torch don't have int4 + {8, element::i8}, + {16, element::i16}, + {32, element::i32}, + {64, element::i64}, +}; +} // namespace + +void align_eltwise_input_types(const NodeContext& context, + ov::Output& lhs, + ov::Output& rhs, + bool align_scalars) { + const auto& lhs_type = lhs.get_element_type(); + const auto& rhs_type = rhs.get_element_type(); + if (lhs_type.is_dynamic() || rhs_type.is_dynamic()) { + // if any of types is not known, align to lhs type. + // TODO: can be fixed with special operation? + rhs = context.mark_node(std::make_shared(rhs, lhs)); + return; + } + + // Both types are static, align types. If float and int types are used convert int type to f32, after that align + // to the largest bitness, if both float or both int, just align bitness + if (lhs_type == rhs_type) + return; + + // if one of operands is scalar, the resulting type is taken from the other operand except when scalar is float + // type and other operand is int, in that case BOTH operands get fp32 type + const auto& lhs_rank = lhs.get_partial_shape().rank(); + const auto& rhs_rank = rhs.get_partial_shape().rank(); + // consider dynamic rank as non scalar + const auto is_lhs_scalar = lhs_rank.is_static() && lhs_rank.get_length() == 0; + const auto is_rhs_scalar = rhs_rank.is_static() && rhs_rank.get_length() == 0; + if (is_lhs_scalar && is_rhs_scalar) { + // if both scalar, align to lhs + rhs = context.mark_node(std::make_shared(rhs, lhs)); + return; + } + auto lhs_dst_type = lhs_type; + auto rhs_dst_type = rhs_type; + if (is_lhs_scalar) { + if (lhs_type.is_real() && !rhs_type.is_real()) { + // if div we need to also align float types to highest bitness regardless of scalar + if (!align_scalars) + lhs_dst_type = element::f32; + rhs_dst_type = element::f32; + } else { + lhs = context.mark_node(std::make_shared(lhs, rhs)); + return; + } + } else if (is_rhs_scalar) { + if (!lhs_type.is_real() && rhs_type.is_real()) { + lhs_dst_type = element::f32; + // if div we need to also align float types to highest bitness regardless of scalar + if (!align_scalars) + rhs_dst_type = element::f32; + } else { + rhs = context.mark_node(std::make_shared(rhs, lhs)); + return; + } + } + + if (lhs_dst_type == element::boolean || rhs_dst_type == element::boolean) { + // Do nothing with bool + return; + } + + if (!lhs_dst_type.is_real() && rhs_dst_type.is_real()) { + lhs_dst_type = element::f32; + } else if (lhs_dst_type.is_real() && !rhs_dst_type.is_real()) { + rhs_dst_type = element::f32; + } + // Align bitness to higher + if (lhs_dst_type.bitwidth() != rhs_dst_type.bitwidth()) { + const auto dst_bitness = std::max(lhs_dst_type.bitwidth(), rhs_dst_type.bitwidth()); + element::Type* type_to_align = &lhs_dst_type; + if (rhs_dst_type.bitwidth() < dst_bitness) + type_to_align = &rhs_dst_type; + if (type_to_align->is_real()) { + *type_to_align = bit_to_float.at(dst_bitness); + } else { + *type_to_align = bit_to_int.at(dst_bitness); + } + } + + // Cast to destination types + if (lhs_dst_type != lhs_type) { + lhs = context.mark_node(std::make_shared(lhs, lhs_dst_type)); + } + if (rhs_dst_type != rhs_type) { + rhs = context.mark_node(std::make_shared(rhs, rhs_dst_type)); + } +} + } // namespace pytorch } // namespace frontend } // namespace ov diff --git a/src/frontends/pytorch/src/utils.hpp 
b/src/frontends/pytorch/src/utils.hpp index 982023e407e9b1..d9af70cbd1b880 100644 --- a/src/frontends/pytorch/src/utils.hpp +++ b/src/frontends/pytorch/src/utils.hpp @@ -55,6 +55,11 @@ std::shared_ptr cast_fw_node(std::shared_ptr // TODO: Elimitate the need of this function by implementing more accurate custom data type handling Any simplified_type_interpret(Any type); +void align_eltwise_input_types(const NodeContext& context, + ov::Output& lhs, + ov::Output& rhs, + bool align_scalars = false); + namespace op { template OutputVector inplace_op(NodeContext& context) { @@ -87,6 +92,20 @@ OutputVector translate_1to1_match_2_inputs(NodeContext& context) { return {context.mark_node(std::make_shared(inputs[0], inputs[1]))}; } +template +OutputVector translate_1to1_match_2_inputs_align_types(NodeContext& context) { + auto inputs = context.inputs(); + FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= 2, "Operation has less then 2 inputs."); + for (int i = 2; i < inputs.size(); i++) { + FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected."); + } + FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0) && !context.input_is_none(1), "Inputs should not be None."); + auto lhs = inputs[0]; + auto rhs = inputs[1]; + align_eltwise_input_types(context, lhs, rhs); + return {context.mark_node(std::make_shared(lhs, rhs))}; +} + inline OutputVector return_false_scalar(NodeContext& context) { return {context.mark_node(ov::op::v0::Constant::create(element::boolean, Shape{}, {false}))}; } diff --git a/tests/layer_tests/py_frontend_tests/test_torch_decoder.py b/tests/layer_tests/py_frontend_tests/test_torch_decoder.py index 279b649b43ba28..7dd416ab73e22e 100644 --- a/tests/layer_tests/py_frontend_tests/test_torch_decoder.py +++ b/tests/layer_tests/py_frontend_tests/test_torch_decoder.py @@ -21,6 +21,7 @@ def get_scripted_model(model): model = torch.jit.script(model) model.eval() model = torch.jit.freeze(model) + print(model.inlined_graph) # will help debugging return model @@ -82,3 +83,364 @@ def test_pytorch_decoder_get_input_type_none(): assert isinstance(list(div_node.inputs())[2].type(), torch.NoneType) nc_decoder = TorchScriptPythonDecoder(model, div_node) assert isinstance(nc_decoder.get_input_type(2).value, DecoderType.PyNone) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_fp16_tensor(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class SomeTensor(torch.nn.Module): + def forward(self): + return torch.tensor([1, 2], dtype=torch.float16) + + model = get_scripted_model(SomeTensor()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.f16 + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_fp32_tensor(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class SomeTensor(torch.nn.Module): + def forward(self): + return torch.tensor([1, 2], dtype=torch.float32) + + model = get_scripted_model(SomeTensor()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + 
nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.f32 + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_fp64_tensor(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class SomeTensor(torch.nn.Module): + def forward(self): + return torch.tensor([1, 2], dtype=torch.float64) + + model = get_scripted_model(SomeTensor()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.f64 + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_bool_tensor(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class SomeTensor(torch.nn.Module): + def forward(self): + return torch.tensor([1, 0], dtype=torch.bool) + + model = get_scripted_model(SomeTensor()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.boolean + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_u8_tensor(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class SomeTensor(torch.nn.Module): + def forward(self): + return torch.tensor([1, 2], dtype=torch.uint8) + + model = get_scripted_model(SomeTensor()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.u8 + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_i8_tensor(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class SomeTensor(torch.nn.Module): + def forward(self): + return torch.tensor([1, 2], dtype=torch.int8) + + model = get_scripted_model(SomeTensor()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.i8 + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_i32_tensor(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class 
SomeTensor(torch.nn.Module): + def forward(self): + return torch.tensor([1, 2], dtype=torch.int) + + model = get_scripted_model(SomeTensor()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.i32 + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_i64_tensor(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class SomeTensor(torch.nn.Module): + def forward(self): + return torch.tensor([1, 2], dtype=torch.int64) + + model = get_scripted_model(SomeTensor()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.i64 + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_int64_max(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + + class I64MaxConst(torch.nn.Module): + def forward(self): + return 9223372036854775807 + + model = get_scripted_model(I64MaxConst()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + int64_const = consts[0] + print(int64_const) + nc_decoder = TorchScriptPythonDecoder(model, int64_const) + assert nc_decoder.as_constant() is not None + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_int_list(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class ListConst(torch.nn.Module): + def forward(self): + return [1, 2] + + model = get_scripted_model(ListConst()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + print(some_const) + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.i32 + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_float_list(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class ListConst(torch.nn.Module): + def forward(self): + return [float(1), float(2)] + + model = get_scripted_model(ListConst()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + print(some_const) + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.f32 + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_bool_list(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import 
PartialShape, Type + + class ListConst(torch.nn.Module): + def forward(self): + return [True, False] + + model = get_scripted_model(ListConst()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + print(some_const) + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.boolean + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_int_tuple(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class ListConst(torch.nn.Module): + def forward(self): + return (1, 2) + + model = get_scripted_model(ListConst()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + print(some_const) + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.i32 + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_float_tuple(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class ListConst(torch.nn.Module): + def forward(self): + return (float(1), float(2)) + + model = get_scripted_model(ListConst()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + print(some_const) + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.f32 + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +@pytest.mark.xfail(reason="Bool tuple gets converted to i32 tuple.") +def test_pytorch_decoder_can_convert_bool_tuple(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class ListConst(torch.nn.Module): + def forward(self): + return (True, False) + + model = get_scripted_model(ListConst()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + print(some_const) + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.boolean + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_empty_list(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class aten_roll(torch.nn.Module): + def __init__(self, shifts): + super(aten_roll, self).__init__() + self.shits = shifts + + def forward(self, x): + # roll has optional input dim, which is empty int list by default + return torch.roll(x, self.shits) + + model = get_scripted_model(aten_roll(1)) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 1 + empty_const = consts[1] + 
print(empty_const) + nc_decoder = TorchScriptPythonDecoder(model, empty_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.i32 + assert ov_const[0].get_partial_shape() == PartialShape([0]) diff --git a/tests/layer_tests/pytorch_tests/test_add.py b/tests/layer_tests/pytorch_tests/test_add.py index 313a781a63a57e..7be3e8becb7855 100644 --- a/tests/layer_tests/pytorch_tests/test_add.py +++ b/tests/layer_tests/pytorch_tests/test_add.py @@ -10,7 +10,8 @@ @pytest.mark.parametrize('alpha', (-0.5, 0, 0.5, 1, 2)) @pytest.mark.parametrize('input_rhs', (np.random.randn(2, 5, 3, 4).astype(np.float32), - np.random.randn(1, 5, 3, 4).astype(np.float32), + np.random.randn( + 1, 5, 3, 4).astype(np.float32), np.random.randn(1).astype(np.float32))) class TestAdd(PytorchLayerTest): @@ -36,3 +37,66 @@ def forward(self, lhs, rhs): def test_add(self, ie_device, precision, ir_version, alpha, input_rhs): self.input_rhs = input_rhs self._test(*self.create_model(alpha), ie_device, precision, ir_version) + + +class TestAddTypes(PytorchLayerTest): + + def _prepare_input(self): + if len(self.lhs_shape) == 0: + return (torch.randn(self.rhs_shape).to(self.rhs_type).numpy(),) + elif len(self.rhs_shape) == 0: + return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),) + return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(), + torch.randn(self.rhs_shape).to(self.rhs_type).numpy()) + + def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape): + + class aten_add(torch.nn.Module): + def __init__(self, lhs_type, lhs_shape, rhs_type, rhs_shape): + super().__init__() + self.lhs_type = lhs_type + self.rhs_type = rhs_type + if len(lhs_shape) == 0: + self.forward = self.forward1 + elif len(rhs_shape) == 0: + self.forward = self.forward2 + else: + self.forward = self.forward3 + + def forward1(self, rhs): + return torch.add(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type), alpha=2) + + def forward2(self, lhs): + return torch.add(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type), alpha=2) + + def forward3(self, lhs, rhs): + return torch.add(lhs.to(self.lhs_type), rhs.to(self.rhs_type), alpha=2) + + ref_net = None + + return aten_add(lhs_type, lhs_shape, rhs_type, rhs_shape), ref_net, "aten::add" + + @pytest.mark.parametrize(("lhs_type", "rhs_type"), + [[torch.int32, torch.int64], + [torch.int32, torch.float32], + [torch.int32, torch.float64], + [torch.int64, torch.int32], + [torch.int64, torch.float32], + [torch.int64, torch.float64], + [torch.float32, torch.int32], + [torch.float32, torch.int64], + [torch.float32, torch.float64], + ]) + @pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]), + ([2, 3], []), + ([], [2, 3]), + ]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_add_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape): + self.lhs_type = lhs_type + self.lhs_shape = lhs_shape + self.rhs_type = rhs_type + self.rhs_shape = rhs_shape + self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape), + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_comparision.py b/tests/layer_tests/pytorch_tests/test_comparision.py index b667d6d2c95c9c..98134a274f7bdb 100644 --- a/tests/layer_tests/pytorch_tests/test_comparision.py +++ b/tests/layer_tests/pytorch_tests/test_comparision.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest +import torch from pytorch_layer_test_class 
import PytorchLayerTest @@ -12,8 +13,6 @@ def _prepare_input(self): return (np.random.randn(1, 3, 24, 24).astype(np.float32), np.random.randn(1, 3, 24, 24).astype(np.float32)) def create_model(self, op_type): - import torch - class aten_eq(torch.nn.Module): def forward(self, x, y): return x == y @@ -57,3 +56,79 @@ def forward(self, x, y): @pytest.mark.precommit def test_comp(self, op, ie_device, precision, ir_version): self._test(*self.create_model(op), ie_device, precision, ir_version) + + +class TestCompMixedTypes(PytorchLayerTest): + + def _prepare_input(self): + if len(self.lhs_shape) == 0: + return (torch.randint(0, 3, self.rhs_shape).to(self.rhs_type).numpy(),) + elif len(self.rhs_shape) == 0: + return (torch.randint(0, 3, self.lhs_shape).to(self.lhs_type).numpy(),) + return (torch.randint(0, 3, self.lhs_shape).to(self.lhs_type).numpy(), + torch.randint(0, 3, self.rhs_shape).to(self.rhs_type).numpy()) + + def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape, op): + + ops = { + "eq": torch.eq, + "ne": torch.ne, + "lt": torch.lt, + "gt": torch.gt, + "ge": torch.ge, + "le": torch.le + } + + op_fn = ops[op] + + class aten_comp(torch.nn.Module): + def __init__(self, lhs_type, lhs_shape, rhs_type, rhs_shape, op_fn): + super().__init__() + self.lhs_type = lhs_type + self.rhs_type = rhs_type + self.op_fn = op_fn + if len(lhs_shape) == 0: + self.forward = self.forward1 + elif len(rhs_shape) == 0: + self.forward = self.forward2 + else: + self.forward = self.forward3 + + def forward1(self, rhs): + return self.op_fn(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type)) + + def forward2(self, lhs): + return self.op_fn(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type)) + + def forward3(self, lhs, rhs): + return self.op_fn(lhs.to(self.lhs_type), rhs.to(self.rhs_type)) + + ref_net = None + + return aten_comp(lhs_type, lhs_shape, rhs_type, rhs_shape, op_fn), ref_net, f"aten::{op}" + + @pytest.mark.parametrize(("lhs_type", "rhs_type"), + [[torch.int32, torch.int64], + [torch.int32, torch.float32], + [torch.int32, torch.float64], + [torch.int64, torch.int32], + [torch.int64, torch.float32], + [torch.int64, torch.float64], + [torch.float32, torch.int32], + [torch.float32, torch.int64], + [torch.float32, torch.float64], + ]) + @pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]), + ([2, 3], []), + ([], [2, 3]), + ]) + @pytest.mark.parametrize("op", ["eq", "ne", "lt", "gt", "le", "ge"]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_eq_mixed_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape, op): + self.lhs_type = lhs_type + self.lhs_shape = lhs_shape + self.rhs_type = rhs_type + self.rhs_shape = rhs_shape + self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape, op), + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_div.py b/tests/layer_tests/pytorch_tests/test_div.py index af7b0ad75736cc..48eaba1a89bb3f 100644 --- a/tests/layer_tests/pytorch_tests/test_div.py +++ b/tests/layer_tests/pytorch_tests/test_div.py @@ -3,6 +3,7 @@ import numpy as np import pytest +import torch from pytorch_layer_test_class import PytorchLayerTest @@ -12,7 +13,6 @@ def _prepare_input(self): return (self.input_array.astype(self.input_type), self.other_array.astype(self.other_type)) def create_model(self, rounding_mode): - import torch class aten_div(torch.nn.Module): def __init__(self, rounding_mode): @@ -26,34 +26,6 @@ def forward(self, input_tensor, other_tensor): return 
aten_div(rounding_mode), ref_net, "aten::div" - @pytest.mark.parametrize(("input_array", "other_array"), [ - [10 * np.random.rand(5, 5), np.random.uniform(low=1, high=5, size=(1))], - [10 * np.random.rand(5, 5, 1), np.random.uniform(low=1, high=5, size=(1))], - [10 * np.random.rand(1, 1, 5, 5), np.random.uniform( - low=1, high=5, size=(1))], - [10 * np.random.rand(5, 5, 1), np.random.uniform( - low=1, high=5, size=(5, 1))] - ]) - @pytest.mark.parametrize(("types"), [ - (np.float32, np.float32), - pytest.param((np.int32, np.float32), marks=pytest.mark.xfail), - pytest.param((np.float32, np.int32), marks=pytest.mark.xfail), - pytest.param((np.int32, np.int32), marks=pytest.mark.xfail) - ]) - @pytest.mark.parametrize('rounding_mode', ([ - None, - "floor", - "trunc" - ])) - @pytest.mark.nightly - def test_div(self, input_array, other_array, types, rounding_mode, ie_device, precision, ir_version): - self.input_array = input_array - self.input_type = types[0] - self.other_array = other_array - self.other_type = types[1] - self._test(*self.create_model(rounding_mode), - ie_device, precision, ir_version) - @pytest.mark.parametrize(("input_array", "other_array"), [ [np.array([0.7620, 2.5548, -0.5944, -0.7438, 0.9274]), np.array(0.5)], [np.array([[-0.3711, -1.9353, -0.4605, -0.2917], @@ -76,3 +48,74 @@ def test_div_pt_spec(self, input_array, other_array, rounding_mode, ie_device, p self.other_type = np.float32 self._test(*self.create_model(rounding_mode), ie_device, precision, ir_version) + + +class TestDivTypes(PytorchLayerTest): + + def _prepare_input(self): + if len(self.lhs_shape) == 0: + return (torch.randint(2, 5, self.rhs_shape).to(self.rhs_type).numpy(),) + elif len(self.rhs_shape) == 0: + return (10 * torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),) + return (10 * torch.randn(self.lhs_shape).to(self.lhs_type).numpy(), + torch.randint(2, 5, self.rhs_shape).to(self.rhs_type).numpy()) + + def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape, rounding_mode): + + class aten_div(torch.nn.Module): + def __init__(self, lhs_type, lhs_shape, rhs_type, rhs_shape, rounding_mode): + super().__init__() + self.lhs_type = lhs_type + self.rhs_type = rhs_type + self.rm = rounding_mode + if len(lhs_shape) == 0: + self.forward = self.forward1 + elif len(rhs_shape) == 0: + self.forward = self.forward2 + else: + self.forward = self.forward3 + + def forward1(self, rhs): + return torch.div(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type), rounding_mode=self.rm) + + def forward2(self, lhs): + return torch.div(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type), rounding_mode=self.rm) + + def forward3(self, lhs, rhs): + return torch.div(lhs.to(self.lhs_type), rhs.to(self.rhs_type), rounding_mode=self.rm) + + ref_net = None + + return aten_div(lhs_type, lhs_shape, rhs_type, rhs_shape, rounding_mode), ref_net, "aten::div" + + @pytest.mark.parametrize(("lhs_type", "rhs_type"), + [[torch.int32, torch.int64], + [torch.int32, torch.float32], + [torch.int32, torch.float64], + [torch.int64, torch.int32], + [torch.int64, torch.float32], + [torch.int64, torch.float64], + [torch.float32, torch.int32], + [torch.float32, torch.int64], + [torch.float32, torch.float64], + ]) + @pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]), + ([2, 3], []), + ([], [2, 3]), + ]) + @pytest.mark.parametrize('rounding_mode', ([ + None, + "floor", + "trunc" + ])) + @pytest.mark.nightly + @pytest.mark.precommit + def test_div_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, 
rhs_type, rhs_shape, rounding_mode): + self.lhs_type = lhs_type + self.lhs_shape = lhs_shape + self.rhs_type = rhs_type + self.rhs_shape = rhs_shape + if rounding_mode == "floor" and not lhs_type.is_floating_point and not rhs_type.is_floating_point: + pytest.skip("Floor rounding mode and int inputs produce wrong results") + self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape, rounding_mode), + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_mul.py b/tests/layer_tests/pytorch_tests/test_mul.py index 3276c42f4ed2d1..02a17e8c38d7d1 100644 --- a/tests/layer_tests/pytorch_tests/test_mul.py +++ b/tests/layer_tests/pytorch_tests/test_mul.py @@ -3,6 +3,7 @@ import numpy as np import pytest +import torch from pytorch_layer_test_class import PytorchLayerTest @@ -12,7 +13,6 @@ def _prepare_input(self): return (self.input_array.astype(self.input_type), self.other_array.astype(self.other_type)) def create_model(self): - import torch class aten_mul(torch.nn.Module): def __init__(self): @@ -26,35 +26,78 @@ def forward(self, input_tensor, other_tensor): return aten_mul(), ref_net, "aten::mul" @pytest.mark.parametrize(("input_array", "other_array"), [ - [np.random.rand(1, 2), np.random.rand(2, 1)], - [np.random.rand(3, 1, 2), np.random.rand(3, 1, 2)], - [np.random.rand(4, 1, 1), np.random.rand(1, 1, 4)], - ]) - @pytest.mark.parametrize(("types"), [ - (np.float32, np.float32), - # Type promotion - pytest.param((np.int32, np.float32), marks=pytest.mark.xfail(reason="101869")), - pytest.param((np.float32, np.int32), marks=pytest.mark.xfail(reason="101869")), - pytest.param((np.int32, np.int32), marks=pytest.mark.xfail(reason="101869")) - ]) - @pytest.mark.nightly - def test_mul_random(self, input_array, other_array, types, ie_device, precision, ir_version): - self.input_array = input_array - self.input_type = types[0] - self.other_array = other_array - self.other_type = types[1] - self._test(*self.create_model(), ie_device, precision, ir_version) - - - @pytest.mark.parametrize(("input_array", "other_array"), [ - [np.array([ 0.2015, -0.4255, 2.6087]), np.array(100)], - [np.array([[ 1.1207], [-0.3137], [0.0700], [0.8378]]), np.array([[0.5146, 0.1216, -0.5244, 2.2382]])], + [np.array([0.2015, -0.4255, 2.6087]), np.array(100)], + [np.array([[1.1207], [-0.3137], [0.0700], [0.8378]]), + np.array([[0.5146, 0.1216, -0.5244, 2.2382]])], ]) @pytest.mark.nightly @pytest.mark.precommit def test_mul_pt_spec(self, input_array, other_array, ie_device, precision, ir_version): - self.input_array = input_array + self.input_array = input_array self.input_type = np.float32 self.other_array = other_array - self.other_type = np.float32 + self.other_type = np.float32 self._test(*self.create_model(), ie_device, precision, ir_version) + + +class TestMulTypes(PytorchLayerTest): + + def _prepare_input(self): + if len(self.lhs_shape) == 0: + return (torch.randn(self.rhs_shape).to(self.rhs_type).numpy(),) + elif len(self.rhs_shape) == 0: + return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),) + return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(), + torch.randn(self.rhs_shape).to(self.rhs_type).numpy()) + + def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape): + + class aten_mul(torch.nn.Module): + def __init__(self, lhs_type, lhs_shape, rhs_type, rhs_shape): + super().__init__() + self.lhs_type = lhs_type + self.rhs_type = rhs_type + if len(lhs_shape) == 0: + self.forward = self.forward1 + elif len(rhs_shape) == 0: + self.forward = self.forward2 + 
else: + self.forward = self.forward3 + + def forward1(self, rhs): + return torch.mul(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type)) + + def forward2(self, lhs): + return torch.mul(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type)) + + def forward3(self, lhs, rhs): + return torch.mul(lhs.to(self.lhs_type), rhs.to(self.rhs_type)) + + ref_net = None + + return aten_mul(lhs_type, lhs_shape, rhs_type, rhs_shape), ref_net, "aten::mul" + + @pytest.mark.parametrize(("lhs_type", "rhs_type"), + [[torch.int32, torch.int64], + [torch.int32, torch.float32], + [torch.int32, torch.float64], + [torch.int64, torch.int32], + [torch.int64, torch.float32], + [torch.int64, torch.float64], + [torch.float32, torch.int32], + [torch.float32, torch.int64], + [torch.float32, torch.float64], + ]) + @pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]), + ([2, 3], []), + ([], [2, 3]), + ]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_mul_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape): + self.lhs_type = lhs_type + self.lhs_shape = lhs_shape + self.rhs_type = rhs_type + self.rhs_shape = rhs_shape + self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape), + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_pow.py b/tests/layer_tests/pytorch_tests/test_pow.py index 6266203b273a8f..b973106d0acbb3 100644 --- a/tests/layer_tests/pytorch_tests/test_pow.py +++ b/tests/layer_tests/pytorch_tests/test_pow.py @@ -42,3 +42,65 @@ def forward(self, input_data, exponent): def test_pow(self, ie_device, precision, ir_version, test_input): self.test_input = test_input self._test(*self.create_model(), ie_device, precision, ir_version) + + +class TestPowMixedTypes(PytorchLayerTest): + def _prepare_input(self): + if len(self.lhs_shape) == 0: + return (torch.randn(self.rhs_shape) * 2 + 0.6).to(self.rhs_type).numpy(), + elif len(self.rhs_shape) == 0: + return (torch.randint(1, 3, self.lhs_shape).to(self.lhs_type).numpy(),) + return (torch.randint(1, 3, self.lhs_shape).to(self.lhs_type).numpy(), + (torch.randn(self.rhs_shape) * 2 + 0.6).to(self.rhs_type).numpy()) + + def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape): + + class aten_pow(torch.nn.Module): + def __init__(self, lhs_type, lhs_shape, rhs_type, rhs_shape): + super().__init__() + self.lhs_type = lhs_type + self.rhs_type = rhs_type + if len(lhs_shape) == 0: + self.forward = self.forward1 + elif len(rhs_shape) == 0: + self.forward = self.forward2 + else: + self.forward = self.forward3 + + def forward1(self, rhs): + return torch.pow(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type)) + + def forward2(self, lhs): + return torch.pow(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type)) + + def forward3(self, lhs, rhs): + return torch.pow(lhs.to(self.lhs_type), rhs.to(self.rhs_type)) + + ref_net = None + + return aten_pow(lhs_type, lhs_shape, rhs_type, rhs_shape), ref_net, "aten::pow" + + @pytest.mark.parametrize(("lhs_type", "rhs_type"), + [[torch.int32, torch.int64], + [torch.int32, torch.float32], + [torch.int32, torch.float64], + [torch.int64, torch.int32], + [torch.int64, torch.float32], + [torch.int64, torch.float64], + [torch.float32, torch.int32], + [torch.float32, torch.int64], + [torch.float32, torch.float64], + ]) + @pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]), + ([2, 3], []), + ([], [2, 3]), + ]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_pow_mixed_types(self, ie_device, 
precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape): + self.lhs_type = lhs_type + self.lhs_shape = lhs_shape + self.rhs_type = rhs_type + self.rhs_shape = rhs_shape + self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape), + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_sub.py b/tests/layer_tests/pytorch_tests/test_sub.py index 95d3d0dc76314c..84244accc3b546 100644 --- a/tests/layer_tests/pytorch_tests/test_sub.py +++ b/tests/layer_tests/pytorch_tests/test_sub.py @@ -24,13 +24,78 @@ def forward(self, x, y, alpha: float): return aten_sub(), ref_net, "aten::sub" @pytest.mark.parametrize('input_data', [(np.random.randn(2, 3, 4).astype(np.float32), - np.random.randn(2, 3, 4).astype(np.float32), + np.random.randn( + 2, 3, 4).astype(np.float32), np.random.randn(1)), (np.random.randn(4, 2, 3).astype(np.float32), - np.random.randn(1, 2, 3).astype(np.float32), - np.random.randn(1)),]) + np.random.randn( + 1, 2, 3).astype(np.float32), + np.random.randn(1)), ]) @pytest.mark.nightly @pytest.mark.precommit def test_sub(self, ie_device, precision, ir_version, input_data): self.input_data = input_data self._test(*self.create_model(), ie_device, precision, ir_version) + + +class TestSubTypes(PytorchLayerTest): + + def _prepare_input(self): + if len(self.lhs_shape) == 0: + return (torch.randn(self.rhs_shape).to(self.rhs_type).numpy(),) + elif len(self.rhs_shape) == 0: + return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),) + return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(), + torch.randn(self.rhs_shape).to(self.rhs_type).numpy()) + + def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape): + + class aten_sub(torch.nn.Module): + def __init__(self, lhs_type, lhs_shape, rhs_type, rhs_shape): + super().__init__() + self.lhs_type = lhs_type + self.rhs_type = rhs_type + if len(lhs_shape) == 0: + self.forward = self.forward1 + elif len(rhs_shape) == 0: + self.forward = self.forward2 + else: + self.forward = self.forward3 + + def forward1(self, rhs): + return torch.sub(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type), alpha=2) + + def forward2(self, lhs): + return torch.sub(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type), alpha=2) + + def forward3(self, lhs, rhs): + return torch.sub(lhs.to(self.lhs_type), rhs.to(self.rhs_type), alpha=2) + + ref_net = None + + return aten_sub(lhs_type, lhs_shape, rhs_type, rhs_shape), ref_net, "aten::sub" + + @pytest.mark.parametrize(("lhs_type", "rhs_type"), + [[torch.int32, torch.int64], + [torch.int32, torch.float32], + # [torch.int32, torch.float64], fp64 produce ov error of eltwise constant fold + [torch.int64, torch.int32], + [torch.int64, torch.float32], + # [torch.int64, torch.float64], fp64 produce ov error of eltwise constant fold + [torch.float32, torch.int32], + [torch.float32, torch.int64], + # [torch.float32, torch.float64], fp64 produce ov error of eltwise constant fold + ]) + @pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]), + ([2, 3], []), + ([], [2, 3]), + ]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_sub_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape): + self.lhs_type = lhs_type + self.lhs_shape = lhs_shape + self.rhs_type = rhs_type + self.rhs_shape = rhs_shape + self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape), + ie_device, precision, ir_version)
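
The new layer tests compare the converted model against PyTorch's own type-promotion behavior. For reference, a short standalone snippet (not part of the patch) showing the reference results the frontend has to reproduce — in particular why translate_div casts two integer inputs to f32 before the true division:

import torch

a = torch.randn(2, 3)            # float32 tensor
b = torch.randint(0, 5, (2, 3))  # int64 tensor

print(torch.add(a, b).dtype)     # torch.float32: the int operand is promoted to float
print(torch.add(b, 2.5).dtype)   # torch.float32: a float scalar promotes an int tensor
print(torch.div(b, b).dtype)     # torch.float32: true division of two int tensors gives float
print(torch.div(b, b, rounding_mode="trunc").dtype)  # torch.int64: with a rounding mode ints stay int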
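
align_eltwise_input_types() in utils.cpp is the core of the change. Below is a minimal Python sketch of its decision procedure, assuming simplified string type names ("i32", "f32", "boolean", "dynamic") and returning only the destination element types; the real implementation works on ov::element::Type and inserts Convert/ConvertLike nodes marked on the NodeContext. The align_scalars flag corresponds to the true argument passed by the div and pow translators.

# Sketch of the promotion rules only; type names like "i32"/"f32" and the helpers
# below are illustrative assumptions, not OpenVINO API.
FLOAT_BY_BITS = {16: "f16", 32: "f32", 64: "f64"}
INT_BY_BITS = {8: "i8", 16: "i16", 32: "i32", 64: "i64"}


def is_float(t):
    return t.startswith("f")


def bits(t):
    return int(t[1:])


def promote(lhs, rhs, lhs_scalar=False, rhs_scalar=False, align_scalars=False):
    """Return the (lhs_dst, rhs_dst) element types both inputs are cast to."""
    if lhs == "dynamic" or rhs == "dynamic":
        return lhs, lhs              # unknown type: rhs is converted like lhs
    if lhs == rhs:
        return lhs, rhs              # already aligned
    if lhs_scalar and rhs_scalar:
        return lhs, lhs              # two scalars: align to lhs
    lhs_dst, rhs_dst = lhs, rhs
    if lhs_scalar:
        if is_float(lhs) and not is_float(rhs):
            # float scalar vs int tensor: both go to f32, unless align_scalars
            # (used by div/pow) keeps the scalar's own float precision
            if not align_scalars:
                lhs_dst = "f32"
            rhs_dst = "f32"
        else:
            return rhs, rhs          # otherwise the scalar follows the tensor type
    elif rhs_scalar:
        if not is_float(lhs) and is_float(rhs):
            lhs_dst = "f32"
            if not align_scalars:
                rhs_dst = "f32"
        else:
            return lhs, lhs
    if "boolean" in (lhs_dst, rhs_dst):
        return lhs_dst, rhs_dst      # bool inputs are left untouched
    # mixed int/float: the integer side is promoted to f32
    if not is_float(lhs_dst) and is_float(rhs_dst):
        lhs_dst = "f32"
    elif is_float(lhs_dst) and not is_float(rhs_dst):
        rhs_dst = "f32"
    # finally align the bit width upwards within the same type class
    if bits(lhs_dst) != bits(rhs_dst):
        target = max(bits(lhs_dst), bits(rhs_dst))
        table = FLOAT_BY_BITS if is_float(lhs_dst) else INT_BY_BITS
        if bits(lhs_dst) < target:
            lhs_dst = table[target]
        else:
            rhs_dst = table[target]
    return lhs_dst, rhs_dst


# a few cases matching the new layer tests
assert promote("i32", "i64") == ("i64", "i64")
assert promote("i32", "f32") == ("f32", "f32")
assert promote("f32", "f64") == ("f64", "f64")
assert promote("f32", "i64", lhs_scalar=True) == ("f32", "f32")
assert promote("f64", "i32", lhs_scalar=True, align_scalars=True) == ("f64", "f64")

The asserts at the end mirror a few of the (lhs_type, rhs_type) pairs covered by the new TestAddTypes/TestDivTypes/TestMulTypes parametrizations above.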