Skip to content

Commit

Permalink
[GPU] selected format adjusts to the required input rank at get_prefe…
Browse files Browse the repository at this point in the history
…rred_format
  • Loading branch information
kelvinchoi-intel committed Jul 3, 2023
1 parent 6be030b commit db145a0
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/plugins/intel_gpu/src/graph/layout_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1647,12 +1647,12 @@ format layout_optimizer::get_preferred_format(program_node& node) {
node.set_preferred_input_fmt(i, fmt);
} else if (in_lay_rank != out_lay_rank) {
auto fmt = get_preferred_format(node.get_dependency(i));
// Check if selected format can be adjusted to the required output rank
// Check if selected format can be adjusted to the required input rank
// If no, use default fotmat instead
try {
format::adjust_to_rank(fmt, out_lay_rank);
format::adjust_to_rank(fmt, in_lay_rank);
} catch (ov::Exception&) {
fmt = format::get_default_format(out_lay_rank);
fmt = format::get_default_format(in_lay_rank);
}
node.set_preferred_input_fmt(i, fmt);
}
Expand Down
3 changes: 2 additions & 1 deletion src/plugins/intel_gpu/tests/unit/passes/handle_reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ TEST(handle_reshape, reshape_input_reorder) {
// converts tensor to default format with rank = reshape_out_rank
// Likely in the future we'll update that reorder so it will use reshape_input_rank
// After that expected in format will be bfzyx
ASSERT_EQ(reshape_layout_in.format, format::bfyx);
// [Updated] get_preferred_format() updated to use 'in_lay_rank' instead of 'out_lay_rank' for preferred input format
ASSERT_EQ(reshape_layout_in.format, format::bfzyx);
ASSERT_EQ(reshape_layout_out.format, format::bfyx);

ov::PartialShape expected_out_shape{-1, 16, 64, 64};
Expand Down
40 changes: 40 additions & 0 deletions src/plugins/intel_gpu/tests/unit/passes/reorder_inputs_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "dft_inst.h"
#include "gather_inst.h"
#include "border_inst.h"
#include "reshape_inst.h"
#include "pass_manager.h"
#include "to_string_utils.h"

Expand Down Expand Up @@ -255,6 +256,45 @@ TEST(reorder_inputs, impl_forcing_basic_format_kernel) {
ASSERT_EQ(out_mem_ptr[7], 0.f);
}

TEST(reorder_inputs, no_add_reorder_infront_of_reshape) {
auto& engine = get_test_engine();

auto in_layout = layout{ ov::PartialShape{-1, -1, 2, 7, 7, 384}, data_types::f32, format::bfwzyx };
auto in_memory = engine.allocate_memory(layout{ ov::PartialShape{1, 2, 2, 7, 7, 384}, data_types::f32, format::bfwzyx });

auto in = generate_random_1d<float>(in_memory->count(), -10, 10);

set_values<float>(in_memory, in);

topology topology;
topology.add(input_layout("input0", in_layout));
topology.add(permute("permute", input_info("input0"), {0, 1, 3, 2, 4, 5}));
topology.add(reshape("reshape", input_info("permute"), true, {1, 14, 14, 384}, {1, 14, 14, 384}));
topology.add(eltwise("eltw", input_info("reshape"), input_info("reshape"), eltwise_mode::sum));
topology.add(reorder("reorder", input_info("eltw"), format::bfyx, data_types::f32));

ExecutionConfig config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
config.set_property(ov::intel_gpu::optimize_data(true));
auto prog = program::build_program(engine, topology, config);

ASSERT_NE(prog, nullptr);
ASSERT_TRUE(has_node_with_type<reshape>(*prog));

ASSERT_TRUE(prog->get_node("reshape").can_be_optimized());
auto reshape_layout_in = prog->get_node("reshape").get_input_layouts()[0];
auto reshape_layout_out = prog->get_node("reshape").get_output_layout();

ASSERT_EQ(reshape_layout_in.format, format::bfwzyx);
ASSERT_EQ(reshape_layout_out.format, format::bfyx);

auto dep_id_of_reshape = prog->get_node("reshape").get_dependencies_ids()[0];
ASSERT_EQ(dep_id_of_reshape, "permute");

ov::PartialShape expected_out_shape{1, 14, 14, 384};
ASSERT_EQ(reshape_layout_out.get_partial_shape(), expected_out_shape);
}

// TODO Not yet implemented
//TEST(reorder_inputs, impl_forcing_conv_format_kernel) {
// auto& engine = get_test_engine();
Expand Down

0 comments on commit db145a0

Please sign in to comment.