From 9d4d38ec2be7e9c760e672b4c897a4338fddc058 Mon Sep 17 00:00:00 2001 From: "Min, Byungil" Date: Fri, 26 May 2023 07:53:46 +0900 Subject: [PATCH] [GPU] Optimize out Gather by converting to implicit crop + Changed Gather if it divides input tensor along batch axis + Converted Gather to cldnn Crop in CreateGatherOpBase + Added implicit Crop condition for batch axis Signed-off-by: Min, Byungil --- .../graph_optimizer/prepare_buffer_fusing.cpp | 175 ++++++++++++------ .../intel_gpu/src/plugin/ops/gather.cpp | 56 ++++-- .../single_layer_tests/gather.cpp | 19 ++ .../tests/unit/shape_infer/gather_si_test.cpp | 9 + .../tests/unit/test_cases/crop_gpu_test.cpp | 30 +++ .../tests/unit/test_cases/gather_gpu_test.cpp | 40 ++++ 6 files changed, 263 insertions(+), 66 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp index 7532a0d842c3eb..d8ed7f897cb83e 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp @@ -317,6 +317,72 @@ static bool can_reshape_be_optimized(const reshape_node& node) { return node.is_in_place() && !node.has_fused_primitives(); } +static bool is_optimizable_padding_for_crop(const crop_node& node) { + const auto& crop_layout = node.get_output_layout(); + auto input_layout = node.get_dependency(0).get_output_layout(); + auto crop_prim = node.get_primitive(); + auto opt_lower_pad = crop_prim->offsets.feature[0]; + auto opt_upper_pad = input_layout.feature() - crop_prim->offsets.feature[0] - crop_layout.get_tensor().feature[0]; + + // do not optimize crop if paddings are not properly aligned + for (auto& usr : node.get_users()) { + auto usr_layout = usr->get_output_layout(); + if (usr_layout.format == format::b_fs_yx_fsv16 && + (opt_lower_pad % 16 != 0 || opt_upper_pad % 16 != 0)) + return false; + + if (input_layout.data_padding.lower_size().batch[0] != 0 || input_layout.data_padding.upper_size().batch[0] != 0 || + input_layout.data_padding.lower_size().spatial[0] != 0 || input_layout.data_padding.upper_size().spatial[0] != 0 || + input_layout.data_padding.lower_size().spatial[1] != 0 || input_layout.data_padding.upper_size().spatial[1] != 0) + return false; + + // oneDNN doesn't support paddings + if (usr->get_preferred_impl_type() == impl_types::onednn) + return false; + } + + return true; +} + +static bool can_crop_be_optimized_along_feature(const crop_node& node) { + const auto& crop_layout = node.get_output_layout(); + auto format = crop_layout.format; + auto input_layout = node.get_dependency(0).get_output_layout(); + const auto& crop_size = crop_layout.get_tensor(); + const auto& out_pad = crop_layout.data_padding; + + if (format == format::bfyx && crop_size.batch[0] == input_layout.batch() && + crop_size.spatial[0] == input_layout.spatial(0) && + crop_size.spatial[1] == input_layout.spatial(1) && out_pad.lower_size().feature[0] == 0 && + out_pad.upper_size().feature[0] == 0 && out_pad.lower_size().batch[0] == 0 && + out_pad.upper_size().batch[0] == 0 && out_pad.lower_size().spatial[0] == 0 && + out_pad.lower_size().spatial[1] == 0 && out_pad.upper_size().spatial[0] == 0 && + out_pad.upper_size().spatial[1] == 0) { + return true; + } + + return false; +} + +static bool can_crop_be_optimized_along_batch(const crop_node& node) { + const auto& crop_layout = node.get_output_layout(); + auto format = crop_layout.format; + auto input_layout = node.get_dependency(0).get_output_layout(); + const auto crop_shape = crop_layout.get_ordered_dims(); + const auto input_shape = input_layout.get_ordered_dims(); + const auto& in_padding = input_layout.data_padding; + const auto& out_padding = crop_layout.data_padding; + + // Check format's order is 'bxxx' and only batch size is different + if (format::is_simple_data_format(format) && format::traits(format)._order[0] == 0 && + std::equal(input_shape.begin()+1, input_shape.end(), crop_shape.begin()+1) && + !out_padding && !in_padding) { + return true; + } + + return false; +} + static void propagate_padding_to_opt_out_users(program_node& node, cldnn::padding padding_data) { if (padding_data == cldnn::padding()) return; @@ -366,6 +432,7 @@ void prepare_buffer_fusing::run(program& p) { if (!can_optimize(node)) continue; + // zero copy program_helpers::do_for_types(*node, [&p](crop_node& node) { // if the node is marked as network output, prevent optimizations which would affect a form of its output, @@ -392,56 +459,38 @@ void prepare_buffer_fusing::run(program& p) { if (p.is_loop_body() && node.get_dependency(0).is_type()) { return; } - // optimization is available for cropping across depth(features) only + + // optimization is available for cropping across depth(features) or batch // if output padding has defined padding across features already it wouldn't // work because it expect to have zeros in the padded area. + if (!is_optimizable_padding_for_crop(node)) + return; + const auto& crop_layout = node.get_output_layout(); - auto format = crop_layout.format; - auto crop_prim = node.get_primitive(); - auto input_layout = node.get_dependency(0).get_output_layout(); const auto& crop_size = crop_layout.get_tensor(); - const auto& out_padd = crop_layout.data_padding; - auto opt_lower_pad = crop_prim->offsets.feature[0]; - auto opt_upper_pad = input_layout.feature() - crop_prim->offsets.feature[0] - crop_size.feature[0]; - - // do not optimize crop if paddings are not properly aligned - for (auto& usr : node.get_users()) { - auto usr_layout = usr->get_output_layout(); - if (usr_layout.format == format::b_fs_yx_fsv16 && - (opt_lower_pad % 16 != 0 || opt_upper_pad % 16 != 0)) - return; - if (input_layout.data_padding.lower_size().batch[0] != 0 || input_layout.data_padding.upper_size().batch[0] != 0 || - input_layout.data_padding.lower_size().spatial[0] != 0 || input_layout.data_padding.upper_size().spatial[0] != 0 || - input_layout.data_padding.lower_size().spatial[1] != 0 || input_layout.data_padding.upper_size().spatial[1] != 0) - return; - // oneDNN doesn't support paddings - if (usr->get_preferred_impl_type() == impl_types::onednn) - return; - } - - if (format == format::bfyx && crop_size.batch[0] == input_layout.batch() && - crop_size.spatial[0] == input_layout.spatial(0) && - crop_size.spatial[1] == input_layout.spatial(1) && out_padd.lower_size().feature[0] == 0 && - out_padd.upper_size().feature[0] == 0 && out_padd.lower_size().batch[0] == 0 && - out_padd.upper_size().batch[0] == 0 && out_padd.lower_size().spatial[0] == 0 && - out_padd.lower_size().spatial[1] == 0 && out_padd.upper_size().spatial[0] == 0 && - out_padd.upper_size().spatial[1] == 0) { - // Regular crop - // crop input buffer - // |___________data____________| - // - // crop output buffer - // |-------->| offsets[f] |<--| - // |_____data____| - // <------------> - // reference size - // - // In-place crop - // crop output buffer - // |_low_pad_|__data_size__|___|<-upper pad + const auto& out_pad = crop_layout.data_padding; + auto input_layout = node.get_dependency(0).get_output_layout(); + auto crop_prim = node.get_primitive(); - // feature num of pad should be accumulated if dep has been optimized out. + // Regular crop + // crop input buffer + // |___________data____________| + // + // crop output buffer + // |-------->| offsets[f] |<--| + // |_____data____| + // <------------> + // reference size + // + // In-place crop + // crop output buffer + // |_low_pad_|__data_size__|___|<-upper pad + if (can_crop_be_optimized_along_feature(node)) { + auto crop_prim = node.get_primitive(); + auto opt_lower_pad = crop_prim->offsets.feature[0]; + auto opt_upper_pad = input_layout.feature() - crop_prim->offsets.feature[0] - crop_size.feature[0]; auto& dep = node.get_dependency(0); + // feature num of pad should be accumulated if dep has been optimized out. if (dep.is_type() && dep.can_be_optimized()) { auto dep_pad = dep.get_output_layout().data_padding; OPENVINO_ASSERT( @@ -454,18 +503,36 @@ void prepare_buffer_fusing::run(program& p) { opt_upper_pad += dep_pad.upper_size().feature[0]; } + // set padding node.set_output_padding( - padding({out_padd.lower_size().batch[0], - opt_lower_pad, - out_padd.lower_size().spatial[0], - out_padd.lower_size().spatial[1]}, - {out_padd.upper_size().batch[0], - opt_upper_pad, - out_padd.upper_size().spatial[0], - out_padd.upper_size().spatial[1]})); - node.can_be_optimized(true); - propagate_padding_to_opt_out_users(node, node.get_output_layout().data_padding); + padding({out_pad.lower_size().batch[0], + opt_lower_pad, + out_pad.lower_size().spatial[0], + out_pad.lower_size().spatial[1]}, + {out_pad.upper_size().batch[0], + opt_upper_pad, + out_pad.upper_size().spatial[0], + out_pad.upper_size().spatial[1]})); + } else if (can_crop_be_optimized_along_batch(node)) { + auto crop_prim = node.get_primitive(); + auto opt_lower_pad = crop_prim->offsets.batch[0]; + auto opt_upper_pad = input_layout.batch() - crop_prim->offsets.batch[0] - crop_size.batch[0]; + + auto new_padding = padding({opt_lower_pad, + out_pad.lower_size().feature[0], + out_pad.lower_size().spatial[0], + out_pad.lower_size().spatial[1]}, + {opt_upper_pad, + out_pad.upper_size().feature[0], + out_pad.upper_size().spatial[0], + out_pad.upper_size().spatial[1]}); + node.set_output_padding(new_padding); + } else { + return; } + + node.can_be_optimized(true); + propagate_padding_to_opt_out_users(node, node.get_output_layout().data_padding); } }); } diff --git a/src/plugins/intel_gpu/src/plugin/ops/gather.cpp b/src/plugins/intel_gpu/src/plugin/ops/gather.cpp index 16dc84eac3bf19..fe43a793ee5e7d 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/gather.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/gather.cpp @@ -4,12 +4,14 @@ #include "intel_gpu/plugin/program.hpp" #include "intel_gpu/plugin/common_utils.hpp" +#include "transformations/utils/utils.hpp" #include "ngraph/op/gather.hpp" #include "intel_gpu/primitives/gather.hpp" #include "intel_gpu/primitives/reorder.hpp" #include "intel_gpu/primitives/reshape.hpp" +#include "intel_gpu/primitives/crop.hpp" using namespace InferenceEngine; namespace ov { @@ -44,12 +46,13 @@ void CreateGatherOpBase(Program& p, const std::shared_ptr& op, const int64_t } // Dynamic path will do shape infer internally, so no need to pass valid out shape for that case - ov::Shape out_shape = op->get_output_partial_shape(0).is_static() ? op->get_output_shape(0) : ov::Shape{}; + bool is_static = op->get_output_partial_shape(0).is_static(); + ov::Shape out_shape = is_static ? op->get_output_shape(0) : ov::Shape{}; // Update output_shape in case of scalar indice bool need_reshape = false; auto out_shape_original = out_shape; - if (!p.use_new_shape_infer() && op->get_output_partial_shape(0).is_static()) { + if (!p.use_new_shape_infer() && is_static) { auto input1_shape = op->get_input_shape(1); if (input1_shape.size() == 0 && batch_dim == 0) { need_reshape = true; @@ -77,21 +80,50 @@ void CreateGatherOpBase(Program& p, const std::shared_ptr& op, const int64_t } } - // gather + // Set layer name for Gather auto reshapeName = layerName + ""; if (need_reshape) { layerName = layerName + "_reshape_output"; } - auto gatherPrim = cldnn::gather(layerName, - reordered_inputs[0], - reordered_inputs[1], - axis, - out_shape, - batch_dim, - support_neg_ind); - - p.add_primitive(*op, gatherPrim); + // Check if Gather could be converted to other primitive + const auto input_shape = op->get_input_partial_shape(0); + const auto input_rank = input_shape.rank().get_length(); + const auto& indices = op->input_value(1); + if (is_static && axis == 0 && input_rank > 1 && indices.get_partial_shape().rank().get_length() == 0 && + std::equal(input_shape.begin()+1, input_shape.end(), out_shape.begin()+1)) { + // Gather -> Crop + // this Gather simply divides an input tensor along Batch axis + auto get_crop_layer_name = [&](std::string name, size_t idx)->std::string { + return (name + "/crop_" + std::to_string(idx)); + }; + + // Get indices info to calculate offset + const auto& indices_node = indices.get_node_shared_ptr(); + auto indices_constant = std::dynamic_pointer_cast(indices_node); + float result = 0.f; + ov::op::util::get_single_value(indices_constant, result); + + // Set tensors for crop shape and offset + InferenceEngine::SizeVector start_offset(input_shape.size()); + start_offset[0] = static_cast(result); + auto offsetTensor = tensor_from_dims(start_offset, 0); + auto outTensor = tensor_from_dims(out_shape, 1); + + // Create Crop + layerName = get_crop_layer_name(layerName, static_cast(result)); + auto cropPrim = cldnn::crop(layerName, reordered_inputs[0], outTensor, offsetTensor); + p.add_primitive(*op, cropPrim); + } else { + auto gatherPrim = cldnn::gather(layerName, + reordered_inputs[0], + reordered_inputs[1], + axis, + out_shape, + batch_dim, + support_neg_ind); + p.add_primitive(*op, gatherPrim); + } // Add reorder and reshape for scalar indice if (need_reshape) { diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/gather.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/gather.cpp index a7993fa9e2cdc9..ac4a3eca0637b9 100644 --- a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/gather.cpp +++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/gather.cpp @@ -559,4 +559,23 @@ INSTANTIATE_TEST_SUITE_P( GatherLayerTest::getTestCaseName ); +const auto GatherAxes0Optimized = []() { + return testing::Combine(testing::ValuesIn({std::vector{4, 8, 2, 2}}), + testing::ValuesIn({std::vector{}}), + testing::ValuesIn({std::tuple{0, 0}}), + testing::ValuesIn(netPrecisionsFP32), + testing::Values(InferenceEngine::Precision::UNSPECIFIED), + testing::Values(InferenceEngine::Precision::UNSPECIFIED), + testing::Values(InferenceEngine::Layout::ANY), + testing::Values(InferenceEngine::Layout::ANY), + testing::Values(CommonTestUtils::DEVICE_GPU)); +}; + +INSTANTIATE_TEST_SUITE_P( + smoke_Gather7Axes0Optimized, + Gather8IndiceScalarLayerTest, + GatherAxes0Optimized(), + Gather8IndiceScalarLayerTest::getTestCaseName +); + } // namespace diff --git a/src/plugins/intel_gpu/tests/unit/shape_infer/gather_si_test.cpp b/src/plugins/intel_gpu/tests/unit/shape_infer/gather_si_test.cpp index 26ac1a600151d1..aedcfb9d4dce5c 100644 --- a/src/plugins/intel_gpu/tests/unit/shape_infer/gather_si_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/shape_infer/gather_si_test.cpp @@ -61,4 +61,13 @@ INSTANTIATE_TEST_SUITE_P(smoke, gather_test, }, })); +INSTANTIATE_TEST_SUITE_P(optimized, gather_test, + testing::ValuesIn(std::vector{ + { + layout{ov::PartialShape{3, 4, 2, 2}, data_types::f32, format::bfyx}, layout{ov::PartialShape{1}, data_types::f32, format::bfyx}, + 0, 0, + layout{ov::PartialShape{1, 4, 2, 2}, data_types::f32, format::bfyx} + }, + })); + } // shape_infer_tests diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/crop_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/crop_gpu_test.cpp index 2fe73506993e42..c9772ff8d197fb 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/crop_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/crop_gpu_test.cpp @@ -1576,3 +1576,33 @@ TEST(crop_gpu, optimized_out_crop) { ASSERT_TRUE(all_primitives["crop1"] == "_optimized_"); ASSERT_TRUE(all_primitives["crop2"] == "_optimized_"); } + +TEST(crop_single_axis, simple_Baxis) { + auto& engine = get_test_engine(); + + auto input1 = engine.allocate_memory({ data_types::f32, format::bfyx, tensor{ 3, 2, 1, 2 } }); + + set_values(input1, { + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f + }); + + topology topology; + topology.add(input_layout("Input", input1->get_layout())); + topology.add(crop("crop", input_info("Input"), tensor{1, 2, 1, 2}, tensor(1, 0, 0, 0))); + topology.add(reorder("reorder", input_info("crop"), format::bfyx, data_types::i8)); + + network network(engine, topology, get_test_default_config(engine)); + + network.set_input_data("Input", input1); + + auto outputs = network.execute(); + + auto output = outputs.at("reorder").get_memory(); + cldnn::mem_lock output_ptr(output, get_test_stream()); + + std::vector expected_results = { + 5, 6, 7, 8 + }; +} diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/gather_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/gather_gpu_test.cpp index 57d3e982be977a..27c0b74204f296 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/gather_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/gather_gpu_test.cpp @@ -2097,3 +2097,43 @@ TEST(gather_gpu_u8, 322_axisF) { TEST(gather_gpu_u8, export_import) { test_gather_gpu_u8_322_axisF(true); } + +TEST(gather_single_axis, simple_Baxis) { + auto& engine = get_test_engine(); + + auto input1 = engine.allocate_memory({ data_types::f32, format::bfyx, tensor{ 3, 2, 1, 2 } }); // Dictionary + auto input2 = engine.allocate_memory({ data_types::i32, format::bfyx, tensor{ 1, 1, 1, 1 } }); // Indexes + int64_t axis = 0; + + set_values(input1, { + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f + }); + + set_values(input2, { + 1 + }); + + topology topology; + topology.add(input_layout("InputDictionary", input1->get_layout())); + topology.add(input_layout("InputText", input2->get_layout())); + topology.add( + gather("gather", input_info("InputDictionary"), input_info("InputText"), axis, ov::Shape{1, 2, 1, 2}) + ); + topology.add(reorder("reorder", input_info("gather"), format::bfyx, data_types::i8)); + + network network(engine, topology, get_test_default_config(engine)); + + network.set_input_data("InputDictionary", input1); + network.set_input_data("InputText", input2); + + auto outputs = network.execute(); + + auto output = outputs.at("reorder").get_memory(); + cldnn::mem_lock output_ptr(output, get_test_stream()); + + std::vector expected_results = { + 5, 6, 7, 8 + }; +}