From 838d792d961caf84954d9570ee12d3c352c55374 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Wed, 14 Jun 2023 11:28:19 +0400 Subject: [PATCH 01/11] [PT FE]: fix unflatten for list construct sizes (#18039) --- src/frontends/pytorch/src/op/unflatten.cpp | 3 ++ .../pytorch_tests/test_unflatten.py | 36 ++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/frontends/pytorch/src/op/unflatten.cpp b/src/frontends/pytorch/src/op/unflatten.cpp index eff0a5130cc09a..673efbc1480161 100644 --- a/src/frontends/pytorch/src/op/unflatten.cpp +++ b/src/frontends/pytorch/src/op/unflatten.cpp @@ -25,6 +25,9 @@ OutputVector translate_unflatten(const NodeContext& context) { auto input = context.get_input(0); auto dim = context.get_input(1); auto sizes = context.get_input(2); + if (context.get_input_type(2).is()) { + sizes = concat_list_construct(sizes); + } auto input_shape = context.mark_node(std::make_shared(input, element::i32)); auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); auto one_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); diff --git a/tests/layer_tests/pytorch_tests/test_unflatten.py b/tests/layer_tests/pytorch_tests/test_unflatten.py index e260b125e11417..3f8e9de3a2b9f1 100644 --- a/tests/layer_tests/pytorch_tests/test_unflatten.py +++ b/tests/layer_tests/pytorch_tests/test_unflatten.py @@ -32,4 +32,38 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit def test_unflatten(self, dim, shape, dtype, ie_device, precision, ir_version): - self._test(*self.create_model(dim, shape), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype}) \ No newline at end of file + self._test(*self.create_model(dim, shape), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype}) + + +class TestUnflattenListSizes(PytorchLayerTest): + def _prepare_input(self, dtype): + return (np.random.uniform(0, 50, (6, 2, 4)).astype(dtype),) + + def create_model(self, dim): + import torch + + class aten_unflatten(torch.nn.Module): + def __init__(self, dim): + super(aten_unflatten, self).__init__() + self.dim = dim + + def forward(self, x): + dim1, dim2, dim3 = x.shape + if self.dim == 0: + sizes = [dim1, -1] + elif self.dim == 1: + sizes = [dim2 // 2, -1] + else: + sizes = [2, dim3 // 2, -1] + return x.unflatten(self.dim, sizes) + + ref_net = None + + return aten_unflatten(dim), ref_net, "aten::unflatten" + + @pytest.mark.parametrize("dim", [0, 1, 2]) + @pytest.mark.parametrize("dtype", ["float32", "int32"]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_unflatten(self, dim, dtype, ie_device, precision, ir_version): + self._test(*self.create_model(dim), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype}) \ No newline at end of file From 9754117a61f6a498f52502e7fc386df919832b45 Mon Sep 17 00:00:00 2001 From: Tatiana Savina Date: Wed, 14 Jun 2023 09:34:06 +0200 Subject: [PATCH 02/11] change classification notebook (#18037) --- docs/notebooks/001-hello-world-with-output.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/notebooks/001-hello-world-with-output.rst b/docs/notebooks/001-hello-world-with-output.rst index a571c294ff0e9f..44004196ea5f61 100644 --- a/docs/notebooks/001-hello-world-with-output.rst +++ b/docs/notebooks/001-hello-world-with-output.rst @@ -47,7 +47,7 @@ Load an Image # Reshape to model input shape. 
input_image = np.expand_dims(input_image, 0) - plt.imshow(image); + plt.imshow(image) From 67dc220d380cf3c59625890d51458004f199edca Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Wed, 14 Jun 2023 11:55:24 +0400 Subject: [PATCH 03/11] [Snippets] Added support of MatMuls with transposed inputs (#17819) --- .../snippets/pass/common_optimizations.hpp | 9 +- .../pass/explicit_transpose_matmul_inputs.hpp | 22 ++- .../src/pass/common_optimizations.cpp | 78 ++++++++-- .../pass/explicit_transpose_matmul_inputs.cpp | 140 ++++++++++-------- .../snippets/src/pass/mha_tokenization.cpp | 108 ++++++-------- .../tests/src/pass/mha_tokenization.cpp | 30 +++- .../shared_tests_instances/snippets/mha.cpp | 15 ++ .../plugin/shared/include/snippets/mha.hpp | 28 ++-- .../plugin/shared/src/snippets/mha.cpp | 58 ++++---- .../include/subgraph_mha.hpp | 25 ++++ .../src/subgraph_mha.cpp | 66 +++++++++ 11 files changed, 395 insertions(+), 184 deletions(-) diff --git a/src/common/snippets/include/snippets/pass/common_optimizations.hpp b/src/common/snippets/include/snippets/pass/common_optimizations.hpp index 08b07339f5fe40..2961603077f275 100644 --- a/src/common/snippets/include/snippets/pass/common_optimizations.hpp +++ b/src/common/snippets/include/snippets/pass/common_optimizations.hpp @@ -5,7 +5,8 @@ #pragma once #include "openvino/pass/graph_rewrite.hpp" -#include "openvino/pass/pattern/matcher.hpp" + +#include "snippets/op/subgraph.hpp" namespace ov { namespace snippets { @@ -15,6 +16,12 @@ class CommonOptimizations : public ov::pass::MatcherPass { public: OPENVINO_RTTI("CommonOptimizations", "0"); CommonOptimizations(); + +private: + // Move up Constants which aren't scalars from body to Subgraph and replace them with Parameters inside body + void ExtractConstants(const std::shared_ptr& subgraph); + // Move up unsupported Transposes on Parameter outputs from body + void ExtractUnsupportedTransposes(const std::shared_ptr& subgraph); }; } // namespace pass diff --git a/src/common/snippets/include/snippets/pass/explicit_transpose_matmul_inputs.hpp b/src/common/snippets/include/snippets/pass/explicit_transpose_matmul_inputs.hpp index dbad1a714b8271..378128d9014b37 100644 --- a/src/common/snippets/include/snippets/pass/explicit_transpose_matmul_inputs.hpp +++ b/src/common/snippets/include/snippets/pass/explicit_transpose_matmul_inputs.hpp @@ -5,7 +5,6 @@ #pragma once #include "openvino/pass/graph_rewrite.hpp" -#include "openvino/pass/pattern/matcher.hpp" namespace ov { namespace snippets { @@ -13,18 +12,25 @@ namespace pass { /** * @interface ExplicitTransposeMatMulInputs - * @brief At the moment Snippets supports Transpose only with order {0, 2, 3, 1}, - * so if there is pattern in graph: - * in0 Transpose{0, 2, 1, 3} - * \ / - * MatMul[false, true] - * We can set false in MatMul parameter `transposed_b` and - * change Transpose order to {0, 2, 3, 1} which is supported by Snippets + * @brief The pass extracts explicit Transpose node from MatMul with transposed_ and moves it to Parameter. + * If there is another Transpose, the pass fuses extracted Transpose and existing Transpose. 
+ * For example, At the moment Snippets supports Transpose only with order {0, 2, 3, 1}, so if there is pattern in graph: + * in0 Transpose{0, 2, 1, 3} + * \ / + * MatMul[false, true] + * We can set `false` in MatMul parameter `transposed_b` and change Transpose order to {0, 2, 3, 1} which is supported by Snippets * @ingroup snippets */ class ExplicitTransposeMatMulInputs: public ov::pass::MatcherPass { public: + OPENVINO_RTTI("ExplicitTransposeMatMulInputs", "0"); ExplicitTransposeMatMulInputs(); + + // Return `True` if all inputs (except 0-th input) have scalar shape. Otherwise returns `False` + static bool are_weights_scalar(const std::shared_ptr& node); + +private: + static void extract(const ov::Input& input); }; } // namespace pass diff --git a/src/common/snippets/src/pass/common_optimizations.cpp b/src/common/snippets/src/pass/common_optimizations.cpp index da55629055a6e5..180207aa841cdd 100644 --- a/src/common/snippets/src/pass/common_optimizations.cpp +++ b/src/common/snippets/src/pass/common_optimizations.cpp @@ -4,27 +4,24 @@ #include "snippets/pass/common_optimizations.hpp" -#include -#include "openvino/opsets/opset1.hpp" -#include -#include "openvino/pass/pattern/op/wrap_type.hpp" - -#include "transformations/utils/utils.hpp" #include "snippets/pass/fq_decomposition.hpp" #include "snippets/pass/softmax_reshape_elimination.hpp" #include "snippets/pass/explicit_transpose_matmul_inputs.hpp" +#include "snippets/pass/transpose_decomposition.hpp" +#include "snippets/pass/fuse_transpose_brgemm.hpp" #include "snippets/op/subgraph.hpp" -#include "snippets/utils.hpp" #include "snippets/itt.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/utils.hpp" + namespace ov { namespace snippets { namespace pass { -// Move up Constants which aren't scalars from body to Subgraph and replace them with Parameters inside body -void ConvertConstantsToParameters(const std::shared_ptr& subgraph) { - OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ConvertConstantsToParameters"); +void CommonOptimizations::ExtractConstants(const std::shared_ptr& subgraph) { + OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ExtractConstants"); auto body = subgraph->body_ptr(); ParameterVector new_parameters; @@ -55,6 +52,52 @@ void ConvertConstantsToParameters(const std::shared_ptr& subgraph) { + OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ExtractUnsupportedTransposes"); + const auto& body = subgraph->body_ptr(); + const auto parameters = body->get_parameters(); + // [107806]: If count of Parameters isn't equal to Subgraph inputs, + // we cannot guarantee correct extraction since we don't have correct connections between body I/O and Subgraph I/O. 
+ OPENVINO_ASSERT(parameters.size() == subgraph->input_values().size(), + "Failed to extract unsupported transposes: the count of Parameters isn't equal to Subgraph inputs"); + + bool updated = false; + for (size_t i = 0; i < parameters.size(); ++i) { + const auto& parameter = parameters[i]; + const auto& consumers = parameter->get_output_target_inputs(0); + if (consumers.size() != 1) + continue; + + const auto transpose = ov::as_type_ptr(consumers.begin()->get_node()->shared_from_this()); + if (!transpose) + continue; + + const auto& order = ov::as_type_ptr(transpose->get_input_node_shared_ptr(1)); + if (!order) + continue; + + const auto order_value = order->cast_vector(); + const auto transpose_child = *(transpose->get_output_target_inputs(0).begin()); + const auto is_brgemm_case = ov::is_type(transpose_child.get_node()->shared_from_this()); + // If Transpose is supported (can be decomposed or fused into Brgemm), skip + if ((is_brgemm_case && FuseTransposeBrgemm::supported_cases.count(order_value) != 0) || + (TransposeDecomposition::supported_cases.count(order_value) != 0)) + continue; + + // If the transpose isn't supported - we have to extract it from Subgraph + transpose->set_argument(0, subgraph->input_value(i)); + subgraph->set_argument(i, transpose); + transpose_child.replace_source_output(parameter); + // Update shape + parameter->set_partial_shape(transpose->get_output_partial_shape(0)); + updated = true; + } + + if (updated) { + subgraph->validate_and_infer_types(); + } +} + CommonOptimizations::CommonOptimizations() { MATCHER_SCOPE(CommonOptimizations); ov::graph_rewrite_callback callback = [this](ov::pass::pattern::Matcher& m) { @@ -65,10 +108,10 @@ CommonOptimizations::CommonOptimizations() { return false; } - auto body = subgraph->body_ptr(); + const auto& body = subgraph->body_ptr(); const auto is_quantized = subgraph->is_quantized(); - // Firsly we should transform all original Converts inside body to ConvertTruncation to save original behavior. + // Firstly, we should transform all original Converts inside body to ConvertTruncation to save original behavior. // Then if Subgraph contains FakeQuantize we enable specific transformation for quantized subgraphs. 
ov::pass::Manager manager; manager.register_pass(); @@ -80,15 +123,18 @@ CommonOptimizations::CommonOptimizations() { manager.run_passes(body); // At the moment only non-scalar Constants of FakeQuantize can be inside Subgraph - // so we can enable ConvertConstantsToParameters pass for quantized models + // so we can enable ExtractConstants pass for quantized models if (is_quantized) { - ConvertConstantsToParameters(subgraph); + ExtractConstants(subgraph); + } + // Extract unsupported Transposes from body + if (subgraph->has_domain_sensitive_ops()) { + ExtractUnsupportedTransposes(subgraph); } return true; }; - auto m = std::make_shared(ov::pass::pattern::wrap_type(), - matcher_name); + auto m = std::make_shared(ov::pass::pattern::wrap_type(), matcher_name); this->register_matcher(m, callback); } diff --git a/src/common/snippets/src/pass/explicit_transpose_matmul_inputs.cpp b/src/common/snippets/src/pass/explicit_transpose_matmul_inputs.cpp index 6948f6dfcf3476..e98a2c3d57a3fb 100644 --- a/src/common/snippets/src/pass/explicit_transpose_matmul_inputs.cpp +++ b/src/common/snippets/src/pass/explicit_transpose_matmul_inputs.cpp @@ -2,79 +2,101 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "snippets/itt.hpp" - #include "snippets/pass/explicit_transpose_matmul_inputs.hpp" -#include "snippets/pass/transpose_decomposition.hpp" + #include "snippets/op/subgraph.hpp" +#include "snippets/itt.hpp" -#include "openvino/core/rt_info.hpp" +#include "openvino/pass/pattern/matcher.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/core/rt_info.hpp" + +bool ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(const std::shared_ptr& node) { + const auto inputs = node->inputs(); + return std::all_of(inputs.begin() + 1, inputs.end(), + [](const ov::Input& in) { + return in.get_partial_shape().is_static() && ov::shape_size(in.get_shape()) == 1; + }); +} + +void ov::snippets::pass::ExplicitTransposeMatMulInputs::extract(const ov::Input& input) { + auto parent = input.get_source_output().get_node_shared_ptr(); + auto transpose = ov::as_type_ptr(parent); + while (!transpose && !ov::is_type(parent)) { + // We can set supported order and transposed_=false only if ops have scalar shapes to avoid shape mismatching + if (!are_weights_scalar(parent)) + break; + + parent = parent->get_input_node_shared_ptr(0); + transpose = ov::as_type_ptr(parent); + } + + // If there isn't another Transpose, need to create new Transpose + if (transpose) { + const auto transpose_pattern = ov::as_type_ptr(transpose->get_input_node_shared_ptr(1)); + OPENVINO_ASSERT(transpose_pattern, + "ExplicitTransposeMatMulInputs expects existing Transpose with Constant order"); + + auto transposed_order = transpose_pattern->cast_vector(); + OPENVINO_ASSERT(transposed_order.size() > 2, "Incorrect Transpose order for ExplicitTransposeMatMulInputs"); + std::swap(*transposed_order.rbegin(), *(transposed_order.rbegin() + 1)); + + auto new_transpose_order = std::make_shared(transpose_pattern->get_element_type(), + ov::Shape{transposed_order.size()}, + transposed_order); + new_transpose_order->set_friendly_name(transpose_pattern->get_friendly_name()); + ov::copy_runtime_info(transpose_pattern, new_transpose_order); + transpose->set_argument(1, new_transpose_order); + return; + } + + // Create new Transpose before Parameter + OPENVINO_ASSERT(ov::is_type(parent), + "ExplicitTransposeMatMulInputs expects Parameter in cases when there isn't existing Transpose on input"); + const auto& consumers = 
parent->get_output_target_inputs(0); + OPENVINO_ASSERT(consumers.size() == 1, + "ExplicitTransposeMatMulInputs expects Parameter with one consumer in cases when there isn't existing Transpose on input"); + // Extract Transpose from MatMul + OPENVINO_ASSERT(input.get_partial_shape().is_static(), "ExplicitTransposeMatMulInputs supports only static shapes"); + const auto rank = input.get_shape().size(); + std::vector transpose_order(rank, 0); + std::iota(transpose_order.begin(), transpose_order.end(), 0); + std::swap(transpose_order[rank - 1], transpose_order[rank - 2]); + + const auto constant_order = std::make_shared(ov::element::i32, ov::Shape{rank}, transpose_order); + const auto new_transpose = std::make_shared(parent, constant_order); // parent is Parameter + const auto consumer_input = *(consumers.begin()); + consumer_input.replace_source_output(new_transpose); +} ov::snippets::pass::ExplicitTransposeMatMulInputs::ExplicitTransposeMatMulInputs() { MATCHER_SCOPE(ExplicitTransposeMatMulInputs); - auto m_matmul0 = std::make_shared( + auto m_matmul0 = std::make_shared( ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()), ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape())); register_matcher(std::make_shared(m_matmul0, matcher_name), [=](ov::pass::pattern::Matcher &m) { - OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::ExplicitTransposeMatMulInputs") - auto root = m.get_match_root(); - bool rewritten = false; - - auto matmul0 = ov::as_type_ptr(root); - if (!matmul0) - return false; - - for (size_t i = 0; i < matmul0->get_input_size(); i++) { - if (i == 0 && !matmul0->get_transpose_a()) - continue; - if (i == 1 && !matmul0->get_transpose_b()) - continue; - - auto parent1 = matmul0->get_input_node_shared_ptr(i); - auto transpose1 = ov::as_type_ptr(parent1); - while (!transpose1 && !ov::is_type(parent1)) { - // We can set supported order and transposed_b(false) only if ops have scalar shapes to avoid shape mismatching - const auto parent_count = parent1->inputs().size(); - bool are_weights_scalar = true; - for (size_t j = 1; j < parent_count; ++j) { - are_weights_scalar = are_weights_scalar && ov::shape_size(parent1->get_input_shape(j)) == 1; - } - if (!are_weights_scalar) - break; - - parent1 = parent1->get_input_node_shared_ptr(0); - transpose1 = ov::as_type_ptr(parent1); + OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::ExplicitTransposeMatMulInputs") + auto root = m.get_match_root(); + bool rewritten = false; + + auto matmul = ov::as_type_ptr(root); + if (!matmul) + return false; + + if (matmul->get_transpose_a()) { + extract(matmul->input(0)); + matmul->set_transpose_a(false); + rewritten |= true; } - if (!transpose1) - continue; - - const auto transpose_pattern = ov::as_type_ptr(transpose1->get_input_node_shared_ptr(1)); - if (!transpose_pattern) - continue; - - auto transposed_order = transpose_pattern->cast_vector(); - std::swap(*transposed_order.rbegin(), *(transposed_order.rbegin() + 1)); - if (pass::TransposeDecomposition::supported_cases.count(transposed_order) == 0) - continue; - - auto new_transpose_order = std::make_shared(transpose_pattern->get_element_type(), - ov::Shape{4}, - transposed_order); - new_transpose_order->set_friendly_name(transpose_pattern->get_friendly_name()); - ov::copy_runtime_info(transpose_pattern, new_transpose_order); - transpose1->set_argument(1, new_transpose_order); - if (i == 0) { - matmul0->set_transpose_a(false); - } else { - matmul0->set_transpose_b(false); + 
if (matmul->get_transpose_b()) { + extract(matmul->input(1)); + matmul->set_transpose_b(false); + rewritten |= true; } - rewritten |= true; - } - return rewritten; - }); + return rewritten; + }); } diff --git a/src/common/snippets/src/pass/mha_tokenization.cpp b/src/common/snippets/src/pass/mha_tokenization.cpp index 864341dc417e53..1f62781c69e909 100644 --- a/src/common/snippets/src/pass/mha_tokenization.cpp +++ b/src/common/snippets/src/pass/mha_tokenization.cpp @@ -7,6 +7,7 @@ #include "snippets/itt.hpp" #include "snippets/pass/collapse_subgraph.hpp" +#include "snippets/pass/explicit_transpose_matmul_inputs.hpp" #include "snippets/op/subgraph.hpp" #include "snippets/op/brgemm.hpp" #include "snippets/utils.hpp" @@ -156,12 +157,7 @@ auto update_intermediate_supported_ops(std::shared_ptr& interm_op, ov: break; // Add node only if there are scalar constants on inputs because of plugin-specific limitation - bool are_weights_scalar = true; - const auto parent_count = parent->get_input_size(); - for (size_t i = 1; i < parent_count; ++i) { - are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1; - } - if (!are_weights_scalar) + if (!ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(parent)) break; ordered_ops.insert(ordered_ops.begin() + shift, parent); @@ -321,22 +317,27 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken /***** Transposes *****/ /* There may be Transpose and Reshape ops on inputs and outputs of MHA-pattern skeleton * We can add them into Subgraph body + * Transpose0 Transpose1 + * \ / + * MatMul0 + * | + * [...] Transpose2 + * \ / + * MatMul1 + * | + * Transpose3 */ - auto tokenize_transpose = [config](const std::shared_ptr& node) -> std::shared_ptr { - return config.mha_token_enable_transpose ? ov::as_type_ptr(node) - : nullptr; - }; - // First input branch of MatMul0 should be executed before second input branch of MatMul0, - // so firstly we insert Transpose1 on the beginning of ordered_ops and then Transpose1 - bool are_weights_scalar = true; + // so firstly we insert Transpose1 on the beginning of ordered_ops and then Transpose0 + // Note: If MatMul0 has transposed_b, we should tokenize only scalars ops from 1st branch + // to move extracted Transpose from MatMul input to body Parameter + auto parent = matmul0->get_input_node_shared_ptr(1); // We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order (or without this Transpose1) // only if these ops have scalar shapes on other inputs. // There is transformation ExplicitTransposeMatMulInputs that set supported order and transposed_b(false). 
// We can allow to call this pass only if ops have scalar shapes to avoid shape mismatching const auto is_transposed_b_0 = matmul0->get_transpose_b(); - auto parent = matmul0->get_input_node_shared_ptr(1); while (is_supported_intermediate_op(parent)) { // All supported ops have only one output port if (parent->get_output_target_inputs(0).size() != 1) @@ -344,15 +345,8 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken // Only if MatMul0 has transposed_b, we have to tokenize scalar ops // to move explicit Transpose from MatMul0 input_1 to Parameter of Subgraph body - if (is_transposed_b_0) { - const auto parent_count = parent->get_input_size(); - bool are_weights_scalar = true; - for (size_t i = 1; i < parent_count; ++i) { - are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1; - } - if (!are_weights_scalar) { - break; - } + if (is_transposed_b_0 && !ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(parent)) { + break; } // To avoid unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin specific limitation) @@ -360,53 +354,45 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken if (const auto fq_node = ov::as_type_ptr(parent)) { hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node); } + potential_body_params_count += get_potential_body_params(parent); ordered_ops.insert(ordered_ops.begin(), parent); - // TODO [107731] To go always through 0-th port - is it safe? + // [107731] To go always through 0-th port - is it safe? parent = parent->get_input_node_shared_ptr(0); } - const auto transpose1 = tokenize_transpose(parent); - if (is_transposed_b_0) { - if (is_valid_transpose(transpose1, {0, 2, 1, 3})) { - // We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order - // only if these ops have scalar shapes on other inputs. - // There is transformation ExplicitTransposeMatMulInputs that set supported order and transposed_b(false). - // We can allow to call this pass only if ops have scalar shapes to avoid shape mismatching - if (are_weights_scalar) { - ordered_ops.insert(ordered_ops.begin(), transpose1); - } else { - return false; + auto tokenize_transpose = [&](const std::shared_ptr& transpose, + bool is_input_transposed, std::vector order, + const ov::NodeVector::const_iterator& pos) { + // If Transpose has valid order for the Transpose fusing (ExplicitTransposeMatMulInputs pass call), tokenize him. + // Otherwise, skip the Transpose. + if (!is_input_transposed) { + if (is_valid_transpose(transpose, order)) { + ordered_ops.insert(pos, transpose); } - } else { - return false; + return; } - } else { - if (is_valid_transpose(transpose1, {0, 2, 3, 1})) { - ordered_ops.insert(ordered_ops.begin(), transpose1); + auto transposed_order = order; + const auto rank = transposed_order.size(); + if (rank < 2) + return; + std::swap(transposed_order[rank - 1], transposed_order[rank - 2]); + if (is_valid_transpose(transpose, transposed_order)) { + ordered_ops.insert(pos, transpose); } - } - - if (transpose1) { - // Between Transpose1 and MatMul0 will be the one Loop because of LoopFusing optimization. - // The Loop will have one Buffer with the same shape both on input and output. 
- // Need to check for precision to get if we need one more register for Buffer - if (matmul0->get_input_element_type(1).size() != transpose1->get_output_element_type(0).size()) { - buffer_count++; - } - } + }; - const auto transpose0 = tokenize_transpose(matmul0->get_input_node_shared_ptr(0)); - if (is_valid_transpose(transpose0, {0, 2, 1, 3})) { - ordered_ops.insert(ordered_ops.begin(), transpose0); - } else if (matmul0->get_transpose_a()) { - return false; - } + auto get_transpose = [config](const std::shared_ptr& node) -> std::shared_ptr { + return config.mha_token_enable_transpose ? ov::as_type_ptr(node) + : nullptr; + }; - const auto transpose2 = tokenize_transpose(matmul1->get_input_node_shared_ptr(1)); - if (is_valid_transpose(transpose2, {0, 2, 1, 3})) { - ordered_ops.push_back(transpose2); - } + const auto transpose1 = get_transpose(parent); + const auto transpose0 = get_transpose(matmul0->get_input_node_shared_ptr(0)); + const auto transpose2 = get_transpose(matmul1->get_input_node_shared_ptr(1)); + tokenize_transpose(transpose1, is_transposed_b_0, {0, 2, 3, 1}, ordered_ops.begin()); + tokenize_transpose(transpose0, matmul0->get_transpose_a(), {0, 2, 1, 3}, ordered_ops.begin()); + tokenize_transpose(transpose2, matmul1->get_transpose_b(), {0, 2, 1, 3}, ordered_ops.end()); ordered_ops.push_back(matmul1); bool are_ops_after_matmul1 = false; @@ -439,7 +425,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken // // Transpose3 if (!are_ops_after_matmul1) { - auto transpose3 = tokenize_transpose(child); + auto transpose3 = get_transpose(child); if (is_valid_transpose(transpose3, {0, 2, 1, 3}) && transpose3->get_input_element_type(0) == matmul1_out_type) { // To avoid Convert between MatMul1 and Transpose3 ordered_ops.push_back(transpose3); diff --git a/src/common/snippets/tests/src/pass/mha_tokenization.cpp b/src/common/snippets/tests/src/pass/mha_tokenization.cpp index 68956a2a626105..86e211d5c1dbb1 100644 --- a/src/common/snippets/tests/src/pass/mha_tokenization.cpp +++ b/src/common/snippets/tests/src/pass/mha_tokenization.cpp @@ -7,7 +7,7 @@ #include #include "snippets/pass/tokenization.hpp" #include "snippets/pass/mha_tokenization.hpp" -#include "snippets/pass/explicit_transpose_matmul_inputs.hpp" +#include "snippets/pass/common_optimizations.hpp" namespace ov { namespace test { @@ -15,9 +15,10 @@ namespace snippets { void TokenizeMHASnippetsTests::run() { ASSERT_TRUE(function); - std::string name; manager.register_pass(); manager.register_pass(); + manager.register_pass(); + disable_rt_info_check(); } TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA) { @@ -43,6 +44,31 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_int_Matmuls) { run(); } +TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_extraction) { + const auto& f = MHATransposedInputFunction(std::vector{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 128, 12, 64}}, true); + function = f.getOriginal(); + function_ref = f.getReference(); + run(); +} + +TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_extraction_and_unsupported_existing_transpose) { + const auto& f = MHATransposedInputFunction(std::vector{{1, 128, 12, 64}, {1, 12, 64, 128}, {1, 128, 12, 64}}, true, + std::vector{0, 3, 1, 2}); + function = f.getOriginal(); + function_ref = f.getReference(); + run(); +} + +TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_fusion) { + const auto& f = MHATransposedInputFunction(std::vector{{1, 128, 12, 64}, {1, 64, 128, 12}, {1, 128, 12, 64}}, false, + 
std::vector{0, 2, 1, 3}); + function = f.getOriginal(); + function_ref = f.getReference(); + run(); +} + + + } // namespace snippets } // namespace test } // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index c210c1bb6ae447..6b3f525f719988 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -227,6 +227,21 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQ, MHAFQ, ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)), MHA::getTestCaseName); +const std::vector> inputShapesTransposedB = { + {{1, 12, 12, 64}, {1, 12, 48, 64}, {1, 12, 48, 64}} +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHATransposedB, MHATransposedB, + ::testing::Combine( + ::testing::ValuesIn(inputShapesTransposedB), + ::testing::Values(std::vector{}), + ::testing::Values(ov::element::f32), + ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(2), + ::testing::Values(1), + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(std::map{})), + MHA::getTestCaseName); } // namespace } // namespace snippets diff --git a/src/tests/functional/plugin/shared/include/snippets/mha.hpp b/src/tests/functional/plugin/shared/include/snippets/mha.hpp index dde9394869fd09..2f9b950c798273 100644 --- a/src/tests/functional/plugin/shared/include/snippets/mha.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/mha.hpp @@ -5,6 +5,7 @@ #pragma once #include "shared_test_classes/base/snippets_test_utils.hpp" +#include "ngraph_helpers/snippets_ngraph_functions/include/snippets_helpers.hpp" namespace ov { namespace test { @@ -30,7 +31,7 @@ class MHA : public testing::WithParamInterface, void SetUp() override; void generate_inputs(const std::vector& targetInputStaticShapes) override; - virtual void init_subgraph(); + virtual std::shared_ptr get_subgraph(); bool m_with_mul = false; std::vector m_input_types; @@ -39,39 +40,42 @@ class MHA : public testing::WithParamInterface, class MHASelect : public MHA { protected: void generate_inputs(const std::vector& targetInputStaticShapes) override; - void init_subgraph() override; + std::shared_ptr get_subgraph() override; }; class MHAWOTransposeOnInputs : public MHA { protected: - void init_subgraph() override; + std::shared_ptr get_subgraph() override; }; class MHAWOTranspose : public MHA { protected: - void init_subgraph() override; + std::shared_ptr get_subgraph() override; +}; + +class MHAMulAdd : public MHA { + std::shared_ptr get_subgraph() override; +}; + +class MHATransposedB : public MHA { + std::shared_ptr get_subgraph() override; }; class MHAINT8MatMul : public MHA { protected: - void init_subgraph() override; + std::shared_ptr get_subgraph() override; }; class MHAFQAfterMatMul : public MHA { protected: - void init_subgraph() override; + std::shared_ptr get_subgraph() override; }; class MHAFQ : public MHA { protected: - void init_subgraph() override; + std::shared_ptr get_subgraph() override; }; -class MHAMulAdd : public MHA { - void init_subgraph() override; -}; - - } // namespace snippets } // namespace test } // namespace ov diff --git a/src/tests/functional/plugin/shared/src/snippets/mha.cpp b/src/tests/functional/plugin/shared/src/snippets/mha.cpp index 9e8bd6c4d79685..d5924f2cb58e6c 100644 --- 
a/src/tests/functional/plugin/shared/src/snippets/mha.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/mha.cpp @@ -51,7 +51,8 @@ void MHA::SetUp() { std::tie(inputShapes, m_input_types, prc, m_with_mul, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); - init_subgraph(); + const auto subgraph_model = get_subgraph(); + function = subgraph_model->getOriginal(); configuration.insert(additionalConfig.begin(), additionalConfig.end()); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { @@ -76,9 +77,8 @@ void MHA::generate_inputs(const std::vector& targetInputStaticSha } } -void MHA::init_subgraph() { - auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, m_input_types, m_with_mul); - function = f.getOriginal(); +std::shared_ptr MHA::get_subgraph() { + return std::make_shared(inputDynamicShapes, m_input_types, m_with_mul); } void MHASelect::generate_inputs(const std::vector& targetInputStaticShapes) { @@ -99,39 +99,36 @@ void MHASelect::generate_inputs(const std::vector& targetInputSta } } -void MHASelect::init_subgraph() { - auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes, m_input_types); - function = f.getOriginal(); +std::shared_ptr MHASelect::get_subgraph() { + return std::make_shared(inputDynamicShapes, m_input_types); } -void MHAWOTransposeOnInputs::init_subgraph() { - auto f = ov::test::snippets::MHAWOTransposeOnInputsFunction(inputDynamicShapes); - function = f.getOriginal(); +std::shared_ptr MHAWOTransposeOnInputs::get_subgraph() { + return std::make_shared(inputDynamicShapes); } -void MHAWOTranspose::init_subgraph() { - auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes, m_input_types); - function = f.getOriginal(); +std::shared_ptr MHAWOTranspose::get_subgraph() { + return std::make_shared(inputDynamicShapes, m_input_types); } -void MHAINT8MatMul::init_subgraph() { - auto f = ov::test::snippets::MHAINT8MatMulFunction(inputDynamicShapes); - function = f.getOriginal(); +std::shared_ptr MHAINT8MatMul::get_subgraph() { + return std::make_shared(inputDynamicShapes); } -void MHAFQAfterMatMul::init_subgraph() { - auto f = ov::test::snippets::MHAFQAfterMatMulFunction(inputDynamicShapes); - function = f.getOriginal(); +std::shared_ptr MHAFQAfterMatMul::get_subgraph() { + return std::make_shared(inputDynamicShapes); } -void MHAFQ::init_subgraph() { - auto f = ov::test::snippets::MHAFQFunction(inputDynamicShapes); - function = f.getOriginal(); +std::shared_ptr MHAFQ::get_subgraph() { + return std::make_shared(inputDynamicShapes); } -void MHAMulAdd::init_subgraph() { - auto f = ov::test::snippets::MHAMulAddFunction(inputDynamicShapes); - function = f.getOriginal(); +std::shared_ptr MHAMulAdd::get_subgraph() { + return std::make_shared(inputDynamicShapes); +} + +std::shared_ptr MHATransposedB::get_subgraph() { + return std::make_shared(inputDynamicShapes, true); } TEST_P(MHA, CompareWithRefImpl) { @@ -153,26 +150,37 @@ TEST_P(MHAWOTransposeOnInputs, CompareWithRefImpl) { } TEST_P(MHAWOTranspose, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } TEST_P(MHAMulAdd, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + run(); + validateNumSubgraphs(); +} + +TEST_P(MHATransposedB, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } TEST_P(MHAINT8MatMul, CompareWithRefImpl) { + 
SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } TEST_P(MHAFQAfterMatMul, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } TEST_P(MHAFQ, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp index 745d5e990f3b66..f339290209409d 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp @@ -289,6 +289,31 @@ class MHAMulAddFunction : public SnippetsFunctionBase { std::shared_ptr initOriginal() const override; }; +/* Graph: + * Transpose/Parameter + * \ / + * MatMul0 [transposed_b = true/false] + * | + * Softmax + * \ / + * MatMul1 + * | + */ +class MHATransposedInputFunction : public SnippetsFunctionBase { +public: + explicit MHATransposedInputFunction(const std::vector& inputShapes, bool transposed_b = false, + std::vector order = {}) + : SnippetsFunctionBase(inputShapes), m_transposed_b(transposed_b), m_order(order) { + NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); + } +protected: + std::shared_ptr initOriginal() const override; + std::shared_ptr initReference() const override; + + bool m_transposed_b = false; + std::vector m_order = {}; +}; + } // namespace snippets } // namespace test } // namespace ov diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp index fdd9fd3a9c1f37..440e3607a2fabe 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp @@ -685,6 +685,72 @@ std::shared_ptr MHAMulAddFunction::initOriginal() const { return std::make_shared(results, ngraphParam, "mha"); } +std::shared_ptr MHATransposedInputFunction::initOriginal() const { + const auto param0 = std::make_shared(precision, input_shapes[0]); + const auto param1 = std::make_shared(precision, input_shapes[1]); + const auto param2 = std::make_shared(precision, input_shapes[2]); + ngraph::ParameterVector ngraphParam = {param0, param1, param2}; + + std::shared_ptr matmul0_in1 = param1; + if (!m_order.empty()) { + const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{m_order.size()}, m_order); + matmul0_in1 = std::make_shared(param1, transposeConst); + } + + const auto matMul0 = std::make_shared(param0, matmul0_in1, false, m_transposed_b); + const auto softmax = std::make_shared(matMul0, -1); + const auto matMul1 = std::make_shared(softmax, param2); + + ngraph::ResultVector results{std::make_shared(matMul1)}; + return std::make_shared(results, ngraphParam, "mha"); +} + +std::shared_ptr MHATransposedInputFunction::initReference() const { + const auto data0 = std::make_shared(precision, input_shapes[0]); + const auto data1 = std::make_shared(precision, input_shapes[1]); + const auto data2 = std::make_shared(precision, input_shapes[2]); + ngraph::ParameterVector ngraphParam = {data0, data1, data2}; + + bool is_supported = ((m_transposed_b && m_order == std::vector{0, 2, 1, 3}) || + (!m_transposed_b && m_order == std::vector{0, 2, 3, 1})); + + std::shared_ptr in1 = data1; + if (!m_order.empty() && !is_supported) { + const auto transposeConst = 
ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{m_order.size()}, m_order); + in1 = std::make_shared(in1, transposeConst); + } + if (m_transposed_b) { + if (m_order != std::vector{0, 2, 1, 3}) { + const auto rank = input_shapes[1].size(); + std::vector transpose_order(rank, 0); + std::iota(transpose_order.begin(), transpose_order.end(), 0); + std::swap(transpose_order[rank - 1], transpose_order[rank - 2]); + const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i32, ov::Shape{transpose_order.size()}, transpose_order); + in1 = std::make_shared(in1, transposeConst); + } + } + + const auto param0 = std::make_shared(precision, data0->get_shape()); + const auto param1 = std::make_shared(precision, in1->get_shape()); + const auto param2 = std::make_shared(precision, data2->get_shape()); + + std::shared_ptr matmul0_in1 = param1; + if (!m_order.empty() && is_supported) { + const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i32, ov::Shape{m_order.size()}, m_order); + matmul0_in1 = std::make_shared(param1, transposeConst); + } + + const auto matMul0 = std::make_shared(param0, matmul0_in1); + const auto softmax = std::make_shared(matMul0, -1); + const auto matMul1 = std::make_shared(softmax, param2); + + auto subgraph = std::make_shared(ov::NodeVector{data0, in1, data2}, + std::make_shared(NodeVector{matMul1}, ov::ParameterVector{param0, param1, param2})); + + ngraph::ResultVector results{std::make_shared(subgraph)}; + return std::make_shared(results, ngraphParam, "mha"); +} + } // namespace snippets } // namespace test } // namespace ov From 008d9a83cecbd6414587158b298c439ea9fd1e54 Mon Sep 17 00:00:00 2001 From: Maciej Smyk Date: Wed, 14 Jun 2023 10:45:47 +0200 Subject: [PATCH 04/11] Update build_linux.md (#18035) --- docs/dev/build_linux.md | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/docs/dev/build_linux.md b/docs/dev/build_linux.md index 6ffda5114a1633..1dc58ba157eb1c 100644 --- a/docs/dev/build_linux.md +++ b/docs/dev/build_linux.md @@ -12,7 +12,13 @@ The software was validated on: - [CMake](https://cmake.org/download/) 3.13 or higher - GCC 7.5 or higher to build OpenVINO Runtime - Python 3.7 - 3.11 for OpenVINO Runtime Python API -- (Optional) [Install Intel® Graphics Compute Runtime for OpenCL™ Driver package 23.13.26032.30](https://github.com/intel/compute-runtime/releases/tag/23.13.26032.30) to enable inference on Intel integrated GPUs. +- (Optional) Install Intel® Graphics Compute Runtime for OpenCL™ Driver package to enable inference on Intel integrated GPUs. Select a driver package from the table below depending on what version of Ubuntu you are installing on. + + | Ubuntu | Driver package | + | --- | ----------- | + | 22.04 | [23.13.26032.30](https://github.com/intel/compute-runtime/releases/tag/23.13.26032.30) | + | 20.04 | [22.24.23453](https://github.com/intel/compute-runtime/releases/tag/22.24.23453) | + | 18.04 | [21.38.21026](https://github.com/intel/compute-runtime/releases/tag/21.38.21026) | ## How to build @@ -36,17 +42,17 @@ The software was validated on: ```sh sudo ./install_build_dependencies.sh ``` - > **NOTE**: By default, the build enables the OpenVINO Runtime GPU plugin to infer models on your Intel® Processor Graphics. This requires you to [Install Intel® Graphics Compute Runtime for OpenCL™ Driver package 23.13.26032.30](https://github.com/intel/compute-runtime/releases/tag/23.13.26032.30) before running the build. 
If you don't want to use the GPU plugin, use the `-DENABLE_INTEL_GPU=OFF` CMake build option and skip the installation of the Intel® Graphics Compute Runtime for OpenCL™ Driver. 3. Create a build folder: -```sh - mkdir build && cd build -``` + ```sh + mkdir build && cd build + ``` + 4. OpenVINO Runtime uses a CMake-based build system. In the created `build` directory, run `cmake` to fetch project dependencies and create Unix makefiles, then run `make` to build the project: -```sh - cmake -DCMAKE_BUILD_TYPE=Release .. - make --jobs=$(nproc --all) -``` + ```sh + cmake -DCMAKE_BUILD_TYPE=Release .. + make --jobs=$(nproc --all) + ``` The process may take some time to finish. ### Additional Build Options @@ -59,6 +65,8 @@ You can use the following additional build options: cmake -DCMAKE_TOOLCHAIN_FILE=/cmake/toolchains/ia32.linux.toolchain.cmake .. ``` +- If you don't want to use the GPU plugin, use the `-DENABLE_INTEL_GPU=OFF` CMake build option and skip the installation of the Intel® Graphics Compute Runtime for OpenCL™ Driver. + - To build the OpenVINO Runtime Python API: 1. Install all additional packages (e.g., cython and opencv) listed in the `/src/bindings/python/src/compatibility/openvino/requirements-dev.txt` file: ```sh From 63a5ec5762cb403aa55da560ca7695ea72a248f1 Mon Sep 17 00:00:00 2001 From: Roman Lyamin Date: Wed, 14 Jun 2023 14:33:58 +0400 Subject: [PATCH 05/11] [GPU] Several fixes for format traits (#18018) --- .../include/intel_gpu/runtime/format.hpp | 2 +- src/plugins/intel_gpu/src/runtime/format.cpp | 166 +++++++++--------- 2 files changed, 84 insertions(+), 84 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/format.hpp b/src/plugins/intel_gpu/include/intel_gpu/runtime/format.hpp index e73ea6604ff71b..5403d89ee1f56f 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/format.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/format.hpp @@ -46,7 +46,7 @@ struct format_traits { /// @brief Characters representing feature map/channel dimensions in an order. static const char* feature_chars() { return "fic"; } /// @brief Characters representing spatial dimensions in an order. - static const char* spatial_chars() { return "xyzhsw"; } + static const char* spatial_chars() { return "xyzwuvhs"; } /// @brief Characters representing group dimensions in an order. static const char* group_chars() { return "g"; } /// @brief Checks if @p c represents batch dimension. 
diff --git a/src/plugins/intel_gpu/src/runtime/format.cpp b/src/plugins/intel_gpu/src/runtime/format.cpp index cc72537cd6ed57..1f9f127ba37ec2 100644 --- a/src/plugins/intel_gpu/src/runtime/format.cpp +++ b/src/plugins/intel_gpu/src/runtime/format.cpp @@ -21,54 +21,54 @@ static const std::map format_traits_map { // Order - dims changing order from rare to often // Inner order - dims order for internal storage in _sizes array // Block sizes - vector of pairs of dimension number (by inner order) and block size ordered from rare to often - // Format B F S G Dims order Order Inner order Block sizes - FMT_TRAITS(yxfb, 1, 1, 2, 0, {2, 3, 1, 0}, "yxfb", "bfxy?", {}), - FMT_TRAITS(byxf, 1, 1, 2, 0, {0, 2, 3, 1}, "byxf", "bfxy?", {}), - FMT_TRAITS(bfyx, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {}), - FMT_TRAITS(fyxb, 1, 1, 2, 0, {1, 2, 3, 0}, "fyxb", "bfxy?", {}), - FMT_TRAITS(byfx, 1, 1, 2, 0, {0, 2, 1, 3}, "byfx", "bfxy?", {}), - FMT_TRAITS(bxfy, 1, 1, 2, 0, {0, 3, 1, 2}, "bxfy", "bfxy?", {}), - FMT_TRAITS(b_fs_yx_fsv2, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{1, 2}}), - FMT_TRAITS(b_fs_yx_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{1, 4}}), - FMT_TRAITS(b_fs_yx_fsv16, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{1, 16}}), - FMT_TRAITS(b_fs_yx_fsv32, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{1, 32}}), - FMT_TRAITS(b_fs_zyx_fsv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{1, 2}}), - FMT_TRAITS(b_fs_zyx_fsv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{1, 4}}), - FMT_TRAITS(b_fs_zyx_fsv32, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{1, 32}}), - FMT_TRAITS(bs_fs_fsv8_bsv8, 1, 1, 0, 0, {0, 1}, "bf", "bf??", {{0, 8}, {1, 8}}), - FMT_TRAITS(bs_fs_fsv8_bsv16, 1, 1, 0, 0, {0, 1}, "bf", "bf??", {{0, 16}, {1, 8}}), - FMT_TRAITS(bs_f_bsv16, 1, 1, 0, 0, {0, 1}, "bf", "bf??", {{0, 16}}), - FMT_TRAITS(winograd_2x3_s1_data, 1, 1, 2, 0, {0, 2, 3, 1}, "bxyf", "bfxy?", {}), - FMT_TRAITS(bzyxf, 1, 1, 3, 0, {0, 2, 3, 4, 1}, "bzyxf", "bfxyz", {}), - FMT_TRAITS(bfzyx, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {}), - FMT_TRAITS(bfwzyx, 1, 1, 4, 0, {0, 1, 2, 3, 4, 5}, "bfwzyx", "bfxyzw", {}), - FMT_TRAITS(bfuwzyx, 1, 1, 5, 0, {0, 1, 2, 3, 4, 5, 6}, "bfuwzyx", "bfxyzwu", {}), + // Format B F S G Dims order Order Inner order Block sizes + FMT_TRAITS(yxfb, 1, 1, 2, 0, {2, 3, 1, 0}, "yxfb", "bfxy", {}), + FMT_TRAITS(byxf, 1, 1, 2, 0, {0, 2, 3, 1}, "byxf", "bfxy", {}), + FMT_TRAITS(bfyx, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {}), + FMT_TRAITS(fyxb, 1, 1, 2, 0, {1, 2, 3, 0}, "fyxb", "bfxy", {}), + FMT_TRAITS(byfx, 1, 1, 2, 0, {0, 2, 1, 3}, "byfx", "bfxy", {}), + FMT_TRAITS(bxfy, 1, 1, 2, 0, {0, 3, 1, 2}, "bxfy", "bfxy", {}), + FMT_TRAITS(b_fs_yx_fsv2, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{1, 2}}), + FMT_TRAITS(b_fs_yx_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{1, 4}}), + FMT_TRAITS(b_fs_yx_fsv16, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{1, 16}}), + FMT_TRAITS(b_fs_yx_fsv32, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{1, 32}}), + FMT_TRAITS(b_fs_zyx_fsv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{1, 2}}), + FMT_TRAITS(b_fs_zyx_fsv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{1, 4}}), + FMT_TRAITS(b_fs_zyx_fsv32, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{1, 32}}), + FMT_TRAITS(bs_fs_fsv8_bsv8, 1, 1, 0, 0, {0, 1}, "bf", "bf", {{0, 8}, {1, 8}}), + FMT_TRAITS(bs_fs_fsv8_bsv16, 1, 1, 0, 0, {0, 1}, "bf", "bf", {{0, 16}, {1, 8}}), + FMT_TRAITS(bs_f_bsv16, 1, 1, 0, 0, {0, 1}, "bf", "bf", {{0, 16}}), + FMT_TRAITS(winograd_2x3_s1_data, 1, 1, 2, 0, {0, 2, 3, 1}, 
"bxyf", "bfxy", {}), + FMT_TRAITS(bzyxf, 1, 1, 3, 0, {0, 2, 3, 4, 1}, "bzyxf", "bfxyz", {}), + FMT_TRAITS(bfzyx, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {}), + FMT_TRAITS(bfwzyx, 1, 1, 4, 0, {0, 1, 2, 3, 4, 5}, "bfwzyx", "bfxyzw", {}), + FMT_TRAITS(bfuwzyx, 1, 1, 5, 0, {0, 1, 2, 3, 4, 5, 6}, "bfuwzyx", "bfxyzwu", {}), FMT_TRAITS(bfvuwzyx, 1, 1, 6, 0, {0, 1, 2, 3, 4, 5, 6, 7}, "bfvuwzyx", "bfxyzwuv", {}), - FMT_TRAITS(fs_b_yx_fsv32, 1, 1, 2, 0, {1, 0, 2, 3}, "fbyx", "bfxy?", {{1, 32}}), - FMT_TRAITS(b_fs_yx_32fp, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {}), - FMT_TRAITS(b_fs_zyx_fsv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{1, 16}}), - FMT_TRAITS(bs_fs_zyx_bsv16_fsv32, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 16 }, {1, 32}}), - FMT_TRAITS(bs_fs_zyx_bsv16_fsv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 16 }, {1, 16}}), - FMT_TRAITS(bs_fs_yx_bsv16_fsv16, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{0, 16 }, {1, 16}}), - FMT_TRAITS(bs_fs_yx_bsv16_fsv32, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{0, 16 }, {1, 32}}), - FMT_TRAITS(bs_fs_yx_bsv4_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{0, 4 }, {1, 4}}), - FMT_TRAITS(bs_fs_yx_bsv8_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{0, 8 }, {1, 4}}), - FMT_TRAITS(bs_fs_zyx_bsv8_fsv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 8 }, {1, 4}}), - FMT_TRAITS(bs_fs_yx_bsv16_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{0, 16 }, {1, 4}}), - FMT_TRAITS(bs_fs_zyx_bsv16_fsv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 16 }, {1, 4}}), - FMT_TRAITS(bs_fs_yx_bsv16_fsv2, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{0, 16 }, {1, 2}}), - FMT_TRAITS(bs_fs_zyx_bsv16_fsv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 16 }, {1, 2}}), - FMT_TRAITS(bs_fs_yx_bsv8_fsv2, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{0, 8 }, {1, 2}}), - FMT_TRAITS(bs_fs_zyx_bsv8_fsv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 8 }, {1, 2}}), - FMT_TRAITS(bs_fs_yx_bsv4_fsv2, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{0, 4 }, {1, 2}}), - FMT_TRAITS(bs_fs_zyx_bsv4_fsv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 4 }, {1, 4}}), - FMT_TRAITS(bs_fs_zyx_bsv4_fsv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 4 }, {1, 2}}), - FMT_TRAITS(bs_fs_zyx_bsv32_fsv32, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 32 }, {1, 32}}), - FMT_TRAITS(bs_fs_zyx_bsv32_fsv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 32 }, {1, 16}}), - FMT_TRAITS(bs_fs_yx_bsv32_fsv32, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{0, 32 }, {1, 32}}), - FMT_TRAITS(bs_fs_yx_bsv32_fsv16, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{0, 32 }, {1, 16}}), - FMT_TRAITS(nv12, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {}), - FMT_TRAITS(image_2d_rgba, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {}), + FMT_TRAITS(fs_b_yx_fsv32, 1, 1, 2, 0, {1, 0, 2, 3}, "fbyx", "bfxy", {{1, 32}}), + FMT_TRAITS(b_fs_yx_32fp, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {}), + FMT_TRAITS(b_fs_zyx_fsv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{1, 16}}), + FMT_TRAITS(bs_fs_zyx_bsv16_fsv32, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 16 }, {1, 32}}), + FMT_TRAITS(bs_fs_zyx_bsv16_fsv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 16 }, {1, 16}}), + FMT_TRAITS(bs_fs_yx_bsv16_fsv16, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 16 }, {1, 16}}), + FMT_TRAITS(bs_fs_yx_bsv16_fsv32, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 16 }, {1, 32}}), + FMT_TRAITS(bs_fs_yx_bsv4_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 4 }, {1, 
4}}), + FMT_TRAITS(bs_fs_yx_bsv8_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 8 }, {1, 4}}), + FMT_TRAITS(bs_fs_zyx_bsv8_fsv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 8 }, {1, 4}}), + FMT_TRAITS(bs_fs_yx_bsv16_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 16 }, {1, 4}}), + FMT_TRAITS(bs_fs_zyx_bsv16_fsv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 16 }, {1, 4}}), + FMT_TRAITS(bs_fs_yx_bsv16_fsv2, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 16 }, {1, 2}}), + FMT_TRAITS(bs_fs_zyx_bsv16_fsv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 16 }, {1, 2}}), + FMT_TRAITS(bs_fs_yx_bsv8_fsv2, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 8 }, {1, 2}}), + FMT_TRAITS(bs_fs_zyx_bsv8_fsv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 8 }, {1, 2}}), + FMT_TRAITS(bs_fs_yx_bsv4_fsv2, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 4 }, {1, 2}}), + FMT_TRAITS(bs_fs_zyx_bsv4_fsv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 4 }, {1, 4}}), + FMT_TRAITS(bs_fs_zyx_bsv4_fsv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 4 }, {1, 2}}), + FMT_TRAITS(bs_fs_zyx_bsv32_fsv32, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 32 }, {1, 32}}), + FMT_TRAITS(bs_fs_zyx_bsv32_fsv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 32 }, {1, 16}}), + FMT_TRAITS(bs_fs_yx_bsv32_fsv32, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 32 }, {1, 32}}), + FMT_TRAITS(bs_fs_yx_bsv32_fsv16, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 32 }, {1, 16}}), + FMT_TRAITS(nv12, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {}), + FMT_TRAITS(image_2d_rgba, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {}), FMT_TRAITS(oiyx, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {}), FMT_TRAITS(ioyx, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {}), @@ -76,27 +76,27 @@ static const std::map format_traits_map { FMT_TRAITS(oyxi, 1, 1, 2, 0, {0, 2, 3, 1}, "oyxi", "oixy", {}), FMT_TRAITS(oyix, 1, 1, 2, 0, {0, 2, 1, 3}, "oyix", "oixy", {}), FMT_TRAITS(oxiy, 1, 1, 2, 0, {0, 3, 1, 2}, "oxiy", "oixy", {}), - FMT_TRAITS(yxio, 1, 1, 2, 0, {2, 3, 1, 0}, "yxio", "oixy?", {}), + FMT_TRAITS(yxio, 1, 1, 2, 0, {2, 3, 1, 0}, "yxio", "oixy", {}), FMT_TRAITS(oizyx, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {}), FMT_TRAITS(iozyx, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {}), FMT_TRAITS(os_is_yx_isv16_osv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 16}, {0, 16}}), - FMT_TRAITS(o_is_yx_isv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{1, 16}}), - FMT_TRAITS(os_yxi_osv16, 1, 1, 2, 0, {0, 2, 3, 1}, "oyxi", "oixy?", {{0, 16}}), - FMT_TRAITS(os_iyx_osv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{0, 16}}), - FMT_TRAITS(os_iyx_osv32, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{0, 32}}), - FMT_TRAITS(os_iyx_osv64, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{0, 64}}), - FMT_TRAITS(winograd_2x3_s1_weights, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {}), - FMT_TRAITS(winograd_2x3_s1_fused_weights, 1, 1, 2, 0, {3, 2, 1, 0}, "xyio", "oixy?", {}), - FMT_TRAITS(winograd_6x3_s1_fused_weights, 1, 1, 2, 0, {3, 2, 1, 0}, "xyio", "oixy?", {}), - FMT_TRAITS(image_2d_weights_winograd_6x3_s1_fbxyb, 1, 1, 2, 0, {3, 2, 1, 0}, "xyio", "oixy?", {}), - FMT_TRAITS(image_2d_weights_winograd_6x3_s1_xfbyb, 1, 1, 2, 0, {3, 2, 1, 0}, "xyio", "oixy?", {}), - FMT_TRAITS(image_2d_weights_c4_fyx_b, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {}), - FMT_TRAITS(image_2d_weights_c1_b_fyx, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {}), - FMT_TRAITS(lstm_weights_dio, 1, 1, 2, 0, {0, 1, 3, 2}, "oixy", "oixy?", {}), - 
FMT_TRAITS(os_is_yx_isa8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{1, 8}, {0, 8}, {1, 4}}), - FMT_TRAITS(os_is_yx_isa8_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{1, 8}, {0, 16}, {1, 4}}), - FMT_TRAITS(os_is_yx_isa8_osv8_isv4_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {}), - FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{0, 4}, {1, 8}, {0, 8}, {1, 2}}), + FMT_TRAITS(o_is_yx_isv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 16}}), + FMT_TRAITS(os_yxi_osv16, 1, 1, 2, 0, {0, 2, 3, 1}, "oyxi", "oixy", {{0, 16}}), + FMT_TRAITS(os_iyx_osv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 16}}), + FMT_TRAITS(os_iyx_osv32, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}}), + FMT_TRAITS(os_iyx_osv64, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 64}}), + FMT_TRAITS(winograd_2x3_s1_weights, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {}), + FMT_TRAITS(winograd_2x3_s1_fused_weights, 1, 1, 2, 0, {3, 2, 1, 0}, "xyio", "oixy", {}), + FMT_TRAITS(winograd_6x3_s1_fused_weights, 1, 1, 2, 0, {3, 2, 1, 0}, "xyio", "oixy", {}), + FMT_TRAITS(image_2d_weights_winograd_6x3_s1_fbxyb, 1, 1, 2, 0, {3, 2, 1, 0}, "xyio", "oixy", {}), + FMT_TRAITS(image_2d_weights_winograd_6x3_s1_xfbyb, 1, 1, 2, 0, {3, 2, 1, 0}, "xyio", "oixy", {}), + FMT_TRAITS(image_2d_weights_c4_fyx_b, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {}), + FMT_TRAITS(image_2d_weights_c1_b_fyx, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {}), + FMT_TRAITS(lstm_weights_dio, 1, 1, 2, 0, {0, 1, 3, 2}, "oixy", "oixy", {}), + FMT_TRAITS(os_is_yx_isa8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 8}, {0, 8}, {1, 4}}), + FMT_TRAITS(os_is_yx_isa8_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 8}, {0, 16}, {1, 4}}), + FMT_TRAITS(os_is_yx_isa8_osv8_isv4_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {}), + FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 4}, {1, 8}, {0, 8}, {1, 2}}), FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 4}, {1, 8}, {0, 8}, {1, 4}}), FMT_TRAITS(os_is_zyx_osa4_isa8_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 4}, {1, 8}, {0, 8}, {1, 2}}), FMT_TRAITS(os_is_zyx_osa4_isa8_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 4}, {1, 8}, {0, 8}, {1, 4}}), @@ -106,25 +106,25 @@ static const std::map format_traits_map { FMT_TRAITS(os_is_zyx_osa2_isa8_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 2}, {1, 8}, {0, 8}, {1, 2}}), FMT_TRAITS(os_is_zyx_isa8_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 8}, {1, 4}}), FMT_TRAITS(os_is_zyx_isa8_osv16_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 16}, {1, 4}}), - FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv4_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{0, 32}, {1, 32}}), + FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv4_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}, {1, 32}}), FMT_TRAITS(os_is_zyx_osa4_isa8_osv8_isv4_swizzled_by_4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 32}, {1, 32}}), FMT_TRAITS(is_os_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy", {{0, 4}, {1, 8}, {0, 8}, {1, 4}}), - FMT_TRAITS(is_os_yx_isa2_osa8_isv8_osv2, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy?", {{1, 2}, {0, 8}, {1, 8}, {0, 2}}), - FMT_TRAITS(is_os_yx_isa4_osa8_isv8_osv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy?", {{1, 4}, {0, 8}, {1, 8}, {0, 4}}), - FMT_TRAITS(is_o_yx_isv32, 1, 1, 2, 0, {1, 0, 2, 3}, 
"oyxi", "oixy?", {{1, 32}}), - FMT_TRAITS(is_o32_yx_isv32_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy?", {}), - FMT_TRAITS(os_is_y_x8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy?", {}), - FMT_TRAITS(os_is_y_x8_osv8_isv4_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy?", {}), - FMT_TRAITS(os_is_yx_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oixy", "oixy?", {{0, 16}, {1, 4}}), + FMT_TRAITS(is_os_yx_isa2_osa8_isv8_osv2, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy", {{1, 2}, {0, 8}, {1, 8}, {0, 2}}), + FMT_TRAITS(is_os_yx_isa4_osa8_isv8_osv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy", {{1, 4}, {0, 8}, {1, 8}, {0, 4}}), + FMT_TRAITS(is_o_yx_isv32, 1, 1, 2, 0, {1, 0, 2, 3}, "oyxi", "oixy", {{1, 32}}), + FMT_TRAITS(is_o32_yx_isv32_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy", {}), + FMT_TRAITS(os_is_y_x8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy", {}), + FMT_TRAITS(os_is_y_x8_osv8_isv4_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy", {}), + FMT_TRAITS(os_is_yx_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 16}, {1, 4}}), FMT_TRAITS(os_is_yx_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 8}, {1, 4}}), FMT_TRAITS(os_is_zyx_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 8}, {1, 4}}), FMT_TRAITS(os_is_yx_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 8}, {1, 2}}), FMT_TRAITS(os_is_zyx_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 8}, {1, 2}}), FMT_TRAITS(os_is_zyx_osv16_isv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 16}, {1, 16}}), - FMT_TRAITS(os_is_yx_osv32_isv4_swizzled_by_2, 1, 1, 2, 0, {0, 1, 2, 3}, "oixy", "oixy?", {{0, 32}, {1, 4}}), - FMT_TRAITS(os_is_yx_osv32_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oixy", "oixy?", {{0, 32}, {1, 4}}), + FMT_TRAITS(os_is_yx_osv32_isv4_swizzled_by_2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}, {1, 4}}), + FMT_TRAITS(os_is_yx_osv32_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}, {1, 4}}), FMT_TRAITS(os_is_zyx_osv32_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 32}, {1, 4}}), - FMT_TRAITS(os_is_yx_osv32_isv32p, 1, 1, 1, 0, {0, 1, 2, 3}, "oixy", "oixy?", {}), + FMT_TRAITS(os_is_yx_osv32_isv32p, 1, 1, 1, 0, {0, 1, 2, 3}, "oiyx", "oixy", {}), FMT_TRAITS(os_is_zyx_isv16_osv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 16}, {0, 16}}), FMT_TRAITS(is_os_zyx_isv16_osv16, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {{1, 16}, {0, 16}}), FMT_TRAITS(is_os_yx_isv16_osv16, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 16}}), @@ -134,11 +134,11 @@ static const std::map format_traits_map { FMT_TRAITS(os_is_zyx_isa8_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 8}, {1, 4}}), FMT_TRAITS(os_is_zyx_isa8_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 8}, {1, 2}}), FMT_TRAITS(os_is_zyx_isa8_osv16_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 16}, {1, 4}}), - FMT_TRAITS(is_os_yx_isa8_osv8_isv2, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy?", {{1, 8}, {0, 8}, {1, 2}}), - FMT_TRAITS(is_os_yx_isa8_osv8_isv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy?", {{1, 8}, {0, 8}, {1, 4}}), - FMT_TRAITS(os_is_yx_isa8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{1, 8}, {0, 8}, {1, 4}}), - FMT_TRAITS(os_is_yx_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{1, 8}, {0, 8}, {1, 2}}), - FMT_TRAITS(os_is_osv32_isv32_swizzled_by_4, 1, 1, 0, 0, {0, 1, 2, 3}, "oixy", "oixy?", {{0, 32}, {1, 32}}), + FMT_TRAITS(is_os_yx_isa8_osv8_isv2, 1, 1, 2, 0, {1, 0, 2, 3}, 
"ioyx", "ioxy", {{1, 8}, {0, 8}, {1, 2}}), + FMT_TRAITS(is_os_yx_isa8_osv8_isv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy", {{1, 8}, {0, 8}, {1, 4}}), + FMT_TRAITS(os_is_yx_isa8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 8}, {0, 8}, {1, 4}}), + FMT_TRAITS(os_is_yx_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 8}, {0, 8}, {1, 2}}), + FMT_TRAITS(os_is_osv32_isv32_swizzled_by_4, 1, 1, 0, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}, {1, 32}}), FMT_TRAITS(os_is_zyx_isv8_osv16_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 16}, {1, 2}}), FMT_TRAITS(os_zyxi_osv16, 1, 1, 3, 0, {0, 2, 3, 4, 1}, "ozyxi", "oixyz", {{0, 16}}), FMT_TRAITS(os_is_yx_isv8_osv16_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 8}, {0, 16}, {1, 2}}), @@ -152,8 +152,8 @@ static const std::map format_traits_map { FMT_TRAITS(iy_xs_os_xsv2_osv8__ao32, 1, 1, 2, 0, {1, 2, 3, 0}, "iyxo", "oixy", {{2, 2}, {0, 8}}), FMT_TRAITS(iy_xs_os_xsv2_osv16__ao32, 1, 1, 2, 0, {1, 2, 3, 0}, "iyxo", "oixy", {{2, 2}, {0, 16}}), FMT_TRAITS(os_i_yxs_osv4_yxsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 4}}), - FMT_TRAITS(os_i_osv16__ai8, 1, 1, 0, 0, {0, 1}, "oi", "oi??", {{1, 8}, {0, 16}}), - FMT_TRAITS(os_i_osv8__ai8, 1, 1, 0, 0, {0, 1}, "oi", "oi??", {{1, 8}, {0, 8}}), + FMT_TRAITS(os_i_osv16__ai8, 1, 1, 0, 0, {0, 1}, "oi", "oi", {{1, 8}, {0, 16}}), + FMT_TRAITS(os_i_osv8__ai8, 1, 1, 0, 0, {0, 1}, "oi", "oi", {{1, 8}, {0, 8}}), FMT_TRAITS(os_y_is_x_osv8_isv2, 1, 1, 2, 0, {0, 2, 1, 3}, "oyix", "oixy", {{0, 8}, {1, 2}}), FMT_TRAITS(os_y_is_x_osv8_isv4, 1, 1, 2, 0, {0, 2, 1, 3}, "oyix", "oixy", {{0, 8}, {1, 4}}), FMT_TRAITS(os_yx_is_osv8_isv2, 1, 1, 2, 0, {0, 2, 3, 1}, "oyxi", "oixy", {{0, 8}, {1, 2}}), From b69c11d8efdb517c043eb85cc1eb9987748d88a0 Mon Sep 17 00:00:00 2001 From: Sebastian Golebiewski Date: Wed, 14 Jun 2023 13:03:59 +0200 Subject: [PATCH 06/11] [DOCS] Restyling tabs --- docs/_static/css/custom.css | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/docs/_static/css/custom.css b/docs/_static/css/custom.css index 991092eab7b456..2a5bc5d9639433 100644 --- a/docs/_static/css/custom.css +++ b/docs/_static/css/custom.css @@ -21,7 +21,40 @@ pre { white-space: pre-wrap; word-wrap: break-word; } - +/* Sphinx-design tabs override */ + +.sd-tab-set>input:checked+label { + border-color: var(--sd-color-tabs-underline-inactive); + color: var(--sd-color-info-text); + background-color: rgb(0 104 181)!important; +} +.sd-tab-set>input:checked+label:hover { + color: --sd-color-info-text; + background-color: rgb(0,74,134)!important; +} +.sd-tab-set>input:not(:checked)+label:hover { + color: var(--sd-color-black)!important; + background-color: rgb(245, 245, 245)!important; + border-color: var(--sd-color-card-header)!important; +} +.sd-tab-set>label { + border-bottom: 0.125rem solid transparent; + margin-right: 10px!important; + margin-bottom: 8px; + color: var(--sd-color-black)!important; + border-color: var(--sd-color-tabs-underline-inactive); + cursor: pointer; + font-size: var(--sd-fontsize-tabs-label); + font-weight: 400!important; + padding: 5px 16px 2px!important; + transition: color 250ms; + width: auto; + z-index: 1; +} +.sd-tab-content { + box-shadow:none!important; + border-top: solid 2px var(--sd-color-tabs-overline)!important; +} /* Navigation panels override */ /* =================================================== */ From d3461074ea6889497f55056d4c912d970151e767 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Wed, 14 Jun 2023 
15:08:45 +0400 Subject: [PATCH 07/11] [PT FE]: support aten::t and inplace tril/triu (#18040) --- src/frontends/pytorch/src/op/transpose.cpp | 10 +++ src/frontends/pytorch/src/op_table.cpp | 5 ++ .../pytorch_tests/test_transpose.py | 61 ++++++++++++++++++- tests/layer_tests/pytorch_tests/test_trilu.py | 51 +++++++++++++++- 4 files changed, 125 insertions(+), 2 deletions(-) diff --git a/src/frontends/pytorch/src/op/transpose.cpp b/src/frontends/pytorch/src/op/transpose.cpp index 9a6cddb3ffb896..23cbb0d25d7622 100644 --- a/src/frontends/pytorch/src/op/transpose.cpp +++ b/src/frontends/pytorch/src/op/transpose.cpp @@ -50,6 +50,16 @@ OutputVector translate_transpose(const NodeContext& context) { return {context.mark_node(std::make_shared(context.get_input(0), scatter))}; }; +OutputVector translate_t(const NodeContext& context) { + num_inputs_check(context, 1, 1); + auto input = context.get_input(0); + if (input.get_partial_shape().rank().is_dynamic() || input.get_partial_shape().rank().get_length() < 2) { + return {input}; + } + auto dims = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {1, 0})); + return {context.mark_node(std::make_shared(input, dims))}; +} + } // namespace op } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 1b7e76f947420b..fdf52d511f0485 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -130,6 +130,7 @@ OP_CONVERTER(translate_square); OP_CONVERTER(translate_squeeze); OP_CONVERTER(translate_sub); OP_CONVERTER(translate_sum); +OP_CONVERTER(translate_t); OP_CONVERTER(translate_to); OP_CONVERTER(translate_topk); OP_CONVERTER(translate_transpose); @@ -348,6 +349,8 @@ const std::map get_supported_ops() { {"aten::squeeze", op::translate_squeeze}, {"aten::sub", op::translate_sub}, {"aten::sum", op::translate_sum}, + {"aten::t", op::translate_t}, + {"aten::t_", op::inplace_op}, {"aten::tan", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten::tan_", op::inplace_op>}, {"aten::tanh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, @@ -357,7 +360,9 @@ const std::map get_supported_ops() { {"aten::topk", op::translate_topk}, {"aten::transpose", op::translate_transpose}, {"aten::tril", op::translate_tril}, + {"aten::tril_", op::inplace_op}, {"aten::triu", op::translate_triu}, + {"aten::triu_", op::inplace_op}, {"aten::type_as", op::translate_1to1_match_2_inputs}, // TODO: overflow semantics is different {"aten::unflatten", op::translate_unflatten}, diff --git a/tests/layer_tests/pytorch_tests/test_transpose.py b/tests/layer_tests/pytorch_tests/test_transpose.py index c1d0bd4f3e7fbb..024cad110ef480 100644 --- a/tests/layer_tests/pytorch_tests/test_transpose.py +++ b/tests/layer_tests/pytorch_tests/test_transpose.py @@ -31,6 +31,65 @@ def forward(self, x): @pytest.mark.parametrize("dim1", [0, 1, 2, 3, -1, -2, -3, -4]) @pytest.mark.nightly @pytest.mark.precommit - def test_relu(self, dim0, dim1, ie_device, precision, ir_version): + def test_transpose(self, dim0, dim1, ie_device, precision, ir_version): self._test(*self.create_model(dim0, dim1), ie_device, precision, ir_version) + + +class TestTSmall(PytorchLayerTest): + def _prepare_input(self, num_dims=2, input_dtype="float32"): + import numpy as np + shape = (2, 3) + if num_dims == 0: + return (np.array(num_dims).astype(input_dtype), ) + return (np.random.randn(*shape[:num_dims]).astype(input_dtype),) + + def create_model(self, num_dims=2, 
inplace=False): + import torch + + class aten_transpose(torch.nn.Module): + def __init__(self, num_dims, inplace): + super(aten_transpose, self).__init__() + if num_dims == 2: + self.forward = self.forward_2d if not inplace else self.forward_2d_inplace + elif num_dims == 1: + self.forward = self.forward_1d if not inplace else self.forward_1d_inplace + else: + if inplace: + self.forward = self.forward_inplace + + def forward_2d(self, x): + x = torch.reshape(x, (2, -1)) + return x.t(), x + + def forward_2d_inplace(self, x): + x = torch.reshape(x, (2, -1)) + return x.t_(), x + + def forward_1d(self, x): + x = torch.reshape(x, (-1, )) + return x.t(), x + + def forward_1d_inplace(self, x): + x = torch.reshape(x, (-1, )) + return x.t_(), x + + def forward(self, x): + return x.t(), x + + def forward_inplace(self, x): + return x.t_(), x + + ref_net = None + + return aten_transpose(num_dims, inplace), ref_net, "aten::t" if not inplace else "aten::t_" + + @pytest.mark.parametrize("num_dims", [0, 1, 2]) + @pytest.mark.parametrize("input_dtype", ["float32", "int32"]) + @pytest.mark.parametrize("inplace", [True, False]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_t_small(self, num_dims, input_dtype, inplace, ie_device, precision, ir_version): + self._test(*self.create_model(num_dims, inplace), + ie_device, precision, ir_version, + kwargs_to_prepare_input={"num_dims": num_dims, "input_dtype": input_dtype}) diff --git a/tests/layer_tests/pytorch_tests/test_trilu.py b/tests/layer_tests/pytorch_tests/test_trilu.py index 14088ebbf956be..28842e101ce6da 100644 --- a/tests/layer_tests/pytorch_tests/test_trilu.py +++ b/tests/layer_tests/pytorch_tests/test_trilu.py @@ -42,4 +42,53 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version): - self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version, kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype}) \ No newline at end of file + self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version, + kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype}) + + +class TestTriuTrilTensor(PytorchLayerTest): + def _prepare_input(self, shape, dtype): + import numpy as np + return (np.random.randn(*shape).astype(dtype),) + + def create_model(self, op, diagonal): + + import torch + + class aten_trilu(torch.nn.Module): + def __init__(self, op, diagonal): + super(aten_trilu, self).__init__() + op_map = { + "tril": self.tril, + "tril_": self.tril_, + "triu": self.triu, + "triu_": self.triu_ + } + self.diagonal = diagonal + self.forward = op_map[op] + + def tril(self, x): + return x.tril(self.diagonal), x + + def tril_(self, x): + return x.tril_(self.diagonal), x + + def triu(self, x): + return x.triu(self.diagonal), x + + def triu_(self, x): + return x.triu_(self.diagonal), x + + ref_net = None + + return aten_trilu(op, diagonal), ref_net, f"aten::{op}" + + @pytest.mark.parametrize("input_shape", [(5, 5), (6, 4), (4, 6)]) + @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8", "uint8", "bool"]) + @pytest.mark.parametrize("diagonal", [0, 1, 2, -1, -2]) + @pytest.mark.parametrize("op", ["triu", "tril", "triu_", "tril_"]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version): + self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version, + kwargs_to_prepare_input={"shape": input_shape, 
"dtype": dtype}) \ No newline at end of file From 1761427ab15250365549577f729368da982c0495 Mon Sep 17 00:00:00 2001 From: Andrei Gorbachev Date: Wed, 14 Jun 2023 12:58:49 +0100 Subject: [PATCH 08/11] fixed fp16 x fp16 overflow in NonMaxSuppression (#18038) --- .../cl_kernels/non_max_suppression_gpu_ref.cl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/non_max_suppression_gpu_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/non_max_suppression_gpu_ref.cl index e2c17a8c7fe07a..36651d8773fe6c 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/non_max_suppression_gpu_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/non_max_suppression_gpu_ref.cl @@ -80,8 +80,8 @@ inline float FUNC(intersectionOverUnion)(const COORD_TYPE_4 boxA, const COORD_TY { #if BOX_ENCODING == 0 /// CORNER - const COORD_TYPE areaA = (boxA[3] - boxA[1]) * (boxA[2] - boxA[0]); - const COORD_TYPE areaB = (boxB[3] - boxB[1]) * (boxB[2] - boxB[0]); + const float areaA = convert_float(boxA[3] - boxA[1]) * convert_float(boxA[2] - boxA[0]); + const float areaB = convert_float(boxB[3] - boxB[1]) * convert_float(boxB[2] - boxB[0]); const COORD_TYPE intersection_ymin = max(boxA[0], boxB[0]); const COORD_TYPE intersection_xmin = max(boxA[1], boxB[1]); @@ -89,8 +89,8 @@ inline float FUNC(intersectionOverUnion)(const COORD_TYPE_4 boxA, const COORD_TY const COORD_TYPE intersection_xmax = min(boxA[3], boxB[3]); #else /// CENTER - const COORD_TYPE areaA = boxA[3] * boxA[2]; - const COORD_TYPE areaB = boxB[3] * boxB[2]; + const float areaA = convert_float(boxA[3]) * convert_float(boxA[2]); + const float areaB = convert_float(boxB[3]) * convert_float(boxB[2]); const COORD_TYPE halfWidthA = boxA[2] / 2; const COORD_TYPE halfHeightA = boxA[3] / 2; const COORD_TYPE halfWidthB = boxB[2] / 2; @@ -105,10 +105,10 @@ inline float FUNC(intersectionOverUnion)(const COORD_TYPE_4 boxA, const COORD_TY if (areaA <= 0.0f || areaB <= 0.0f) return 0.0f; - const COORD_TYPE intersection_area = max(intersection_xmax - intersection_xmin, TO_COORD_TYPE(0.f)) * - max(intersection_ymax - intersection_ymin, TO_COORD_TYPE(0.f)); - const COORD_TYPE union_area = areaA + areaB - intersection_area; - return convert_float(intersection_area / union_area); + const float intersection_area = convert_float(max(intersection_xmax - intersection_xmin, TO_COORD_TYPE(0.f))) * + convert_float(max(intersection_ymax - intersection_ymin, TO_COORD_TYPE(0.f))); + const float union_area = areaA + areaB - intersection_area; + return intersection_area / union_area; } inline float FUNC(scaleIOU)(float iou, float iou_threshold, float scale) From 92ac04dcaccac2410758ff64bbc42284d2ada420 Mon Sep 17 00:00:00 2001 From: Sebastian Golebiewski Date: Wed, 14 Jun 2023 14:01:33 +0200 Subject: [PATCH 09/11] [DOCS] Restyling tabs fix Fixing style of active tabs. 
--- docs/_static/css/custom.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_static/css/custom.css b/docs/_static/css/custom.css index 2a5bc5d9639433..fe017816f7a428 100644 --- a/docs/_static/css/custom.css +++ b/docs/_static/css/custom.css @@ -25,7 +25,7 @@ pre { .sd-tab-set>input:checked+label { border-color: var(--sd-color-tabs-underline-inactive); - color: var(--sd-color-info-text); + color: var(--sd-color-info-text)!important; background-color: rgb(0 104 181)!important; } .sd-tab-set>input:checked+label:hover { From 277e759dcd4954b342a40b57c0d08e7ccb877ee8 Mon Sep 17 00:00:00 2001 From: Wanglei Shen Date: Wed, 14 Jun 2023 21:19:39 +0800 Subject: [PATCH 10/11] rebase (#17212) --- .../behavior/ov_executable_network/properties.cpp | 4 ++-- .../behavior/ov_plugin/properties_tests.cpp | 15 ++++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/plugins/intel_cpu/tests/functional/behavior/ov_executable_network/properties.cpp b/src/plugins/intel_cpu/tests/functional/behavior/ov_executable_network/properties.cpp index 28c48deee8bfb3..b4436bf73a7e5e 100644 --- a/src/plugins/intel_cpu/tests/functional/behavior/ov_executable_network/properties.cpp +++ b/src/plugins/intel_cpu/tests/functional/behavior/ov_executable_network/properties.cpp @@ -96,7 +96,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriori TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriorityThanLatencyHint) { ov::Core ie; - int32_t streams = 4; // latency hint should apply lower number of streams + int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams int32_t value = 0; ASSERT_NO_THROW(ie.set_property(deviceName, ov::num_streams(streams))); @@ -109,7 +109,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriori TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelStreamsHasHigherPriorityThanLatencyHint) { ov::Core ie; - int32_t streams = 4; // latency hint should apply lower number of streams + int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams int32_t value = 0; ASSERT_NO_THROW(ie.set_property(deviceName, ov::hint::performance_mode(ov::hint::PerformanceMode::LATENCY))); diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/behavior/ov_plugin/properties_tests.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/behavior/ov_plugin/properties_tests.cpp index 459ac5f46af57b..be02c88530ebc0 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/behavior/ov_plugin/properties_tests.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/behavior/ov_plugin/properties_tests.cpp @@ -197,20 +197,21 @@ INSTANTIATE_TEST_SUITE_P(smoke_OVClassSetDevicePriorityConfigPropsTest, ::testing::ValuesIn(multiConfigs))); const std::vector configsDeviceProperties = { - {ov::device::properties("CPU", ov::num_streams(3))}, - {ov::device::properties(ov::AnyMap{{"CPU", ov::AnyMap{ov::num_streams(3)}}})}}; + {ov::device::properties("CPU", ov::num_streams(2))}, + {ov::device::properties(ov::AnyMap{{"CPU", ov::AnyMap{ov::num_streams(2)}}})}}; const std::vector configsDevicePropertiesDouble = { - {ov::device::properties("CPU", ov::num_streams(3)), ov::num_streams(5)}, - {ov::device::properties("CPU", ov::num_streams(3)), + {ov::device::properties("CPU", ov::num_streams(2)), ov::num_streams(5)}, + {ov::device::properties("CPU", ov::num_streams(2)), 
ov::device::properties(ov::AnyMap{{"CPU", ov::AnyMap{ov::num_streams(7)}}}), ov::num_streams(5)}, - {ov::device::properties("CPU", ov::num_streams(3)), ov::device::properties("CPU", ov::num_streams(5))}, - {ov::device::properties("CPU", ov::num_streams(3)), + {ov::device::properties("CPU", ov::num_streams(2)), ov::device::properties("CPU", ov::num_streams(5))}, + {ov::device::properties("CPU", ov::num_streams(1)), ov::device::properties(ov::AnyMap{{"CPU", ov::AnyMap{ov::num_streams(5)}}})}, - {ov::device::properties(ov::AnyMap{{"CPU", ov::AnyMap{ov::num_streams(3)}}}), + {ov::device::properties(ov::AnyMap{{"CPU", ov::AnyMap{ov::num_streams(1)}}}), ov::device::properties(ov::AnyMap{{"CPU", ov::AnyMap{ov::num_streams(5)}}})}}; + // IE Class load and check network with ov::device::properties INSTANTIATE_TEST_SUITE_P(smoke_CPU_OVClassCompileModelAndCheckSecondaryPropertiesTest, OVClassCompileModelAndCheckSecondaryPropertiesTest, From 5993c4942ad315d413eab7b282a237af3382bc57 Mon Sep 17 00:00:00 2001 From: Anastasiia Pnevskaia Date: Wed, 14 Jun 2023 16:43:40 +0200 Subject: [PATCH 11/11] Fixed load_by_model() test. (#18060) --- src/bindings/python/tests/test_frontend/test_frontendmanager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bindings/python/tests/test_frontend/test_frontendmanager.py b/src/bindings/python/tests/test_frontend/test_frontendmanager.py index e1ed922bf0b690..6b3f922f7e424c 100644 --- a/src/bindings/python/tests/test_frontend/test_frontendmanager.py +++ b/src/bindings/python/tests/test_frontend/test_frontendmanager.py @@ -115,7 +115,7 @@ def __str__(self): @mock_needed def test_load_by_model(): clear_all_stat() - fe = fem.load_by_model(model_path="abc.test_mock_py_mdl") + fe = fem.load_by_model(model="abc.test_mock_py_mdl") assert fe is not None assert fe.get_name() == MOCK_PY_FRONTEND_NAME stat = get_fe_stat()
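For reference, the keyword rename in this last patch matches the FrontEndManager Python API, where load_by_model() selects a suitable frontend for the given model. The short sketch below shows the same call outside the test harness; it assumes some registered frontend (the mock Python frontend of the test environment, or a real one) can read the file, and the trailing load()/convert() calls are illustrative rather than part of the patch:

    from openvino.frontend import FrontEndManager

    fem = FrontEndManager()
    # Same keyword form as in the updated test; the argument is the model path.
    fe = fem.load_by_model(model="abc.test_mock_py_mdl")
    if fe is not None:
        print(fe.get_name())                           # frontend that claimed the file
        input_model = fe.load("abc.test_mock_py_mdl")  # parse the model
        ov_model = fe.convert(input_model)             # convert to an openvino Model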