
[GPU] Update weights reorder output shape for fully_connected
Lyamin-Roman committed Sep 19, 2023
1 parent 10dc2d8 commit 885bfde
Showing 3 changed files with 73 additions and 14 deletions.
@@ -84,20 +84,6 @@ void post_optimize_weights::optimize_weights(T& node, program& p) {
            !prev_node.has_fused_primitives() &&
            !prev_node.as<reorder>().has_mean() &&
            prev_node.as<reorder>().get_primitive()->subtract_per_feature.empty();
        if (impl->is_dynamic()) {
            if (weights_reorder_params->get_output_layout().compatible(prev_node.get_output_layout())) {
                // if compatible, it can be reinterpreted, thus no need to reorder at build time
                continue;
            }
            // Need to restore the original shape
            auto updated_output_layout = weights_reorder_params->get_output_layout();
            auto orig_rank = prev_node.get_output_layout().get_partial_shape().size();
            auto weight_format_dims = format::dimension(weights_reorder_params->get_output_layout().format);
            updated_output_layout.set_partial_shape(
                updated_output_layout.get_tensor().get_partial_shape(orig_rank, weight_format_dims));
            if (updated_output_layout != weights_reorder_params->get_output_layout())
                weights_reorder_params->set_output_layout(updated_output_layout);
        }
        if (can_be_fused) {
            // Need to update input data_type for correct merging format reorder with precision reorder
            data_types input_dtype = prev_node.get_input_layouts()[0].data_type;
25 changes: 25 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/fully_connected.cpp
@@ -19,6 +19,31 @@ struct fully_connected_impl : typed_primitive_impl_ocl<fully_connected> {

    DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::fully_connected_impl)

    fully_connected_impl() = default;

    fully_connected_impl(const kernel_selector::kernel_data& kd) {
        const auto& params = kd.weightsReorderParams;

        if (params.is_initialized) {
            // Assumption that kernel data contains already reshaped 2d weights
            auto crop_to_2d = [](const ov::PartialShape& shape) {
                return ov::PartialShape({shape[0], shape[1]});
            };

            auto weights_reorder_params = std::make_shared<WeightsReorderParams>(from_weights_tensor(params.src),
                                                                                 from_weights_tensor(params.dest),
                                                                                 params.rotate);
            auto output_layout = weights_reorder_params->get_output_layout();
            output_layout.set_partial_shape(crop_to_2d(output_layout.get_partial_shape()));
            weights_reorder_params->set_output_layout(output_layout);

            _weights_reorder_params = weights_reorder_params;
        }
        _kernel_data = kd;
        _kernel_name = kd.kernelName;
        can_reuse_memory = _kernel_data.can_reuse_memory;
    }

    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<fully_connected_impl>(*this);
    }
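For illustration only, and not part of the commit: the crop_to_2d lambda above keeps just the first two dimensions of whatever shape the kernel selector reports for the reordered weights. A minimal standalone sketch using only the public ov::PartialShape API; the 4D sample shape {4, 3, 1, 1} is hypothetical.

#include <iostream>
#include "openvino/core/partial_shape.hpp"

int main() {
    // Same idea as the lambda in fully_connected_impl: keep only the first two dims.
    auto crop_to_2d = [](const ov::PartialShape& shape) {
        return ov::PartialShape({shape[0], shape[1]});
    };

    ov::PartialShape reordered{4, 3, 1, 1};   // hypothetical 4D weights shape (e.g. an oiyx-style layout)
    ov::PartialShape cropped = crop_to_2d(reordered);

    std::cout << cropped << std::endl;        // prints the cropped 2D shape, e.g. [4,3]
    return 0;
}

Cropping here keeps the stored weights reorder output layout 2D, matching the original fully_connected weights shape, which is what the test added below asserts.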
@@ -3012,3 +3012,51 @@ INSTANTIATE_TEST_SUITE_P(
    ),
    fully_connected_types_u8_f32_test::PrintToStringParamName
);

TEST(fully_connected_gpu, weights_reorder_shapes_update_test) {
    auto& engine = get_test_engine();

    const int32_t input_f = 3, input_b = 1, weight_b = 4;

    auto input_dyn_layout = layout{ ov::PartialShape{ ov::Dimension(1, 10), input_f }, data_types::f32, format::bfyx };
    auto input_data = engine.allocate_memory(layout{ ov::PartialShape{ input_b, input_f }, data_types::f32, format::bfyx });
    auto weights_data = engine.allocate_memory({ ov::PartialShape{ weight_b, input_f }, data_types::f32, format::bfyx });

    set_values(input_data, { -0.5f, 2.0f, 0.5f });
    set_values(weights_data, { 1.5f, 1.0f, 0.5f, -1.0f, 0.0f, 0.5f, 0.5f, -0.5f, -2.0f, -0.5f, 1.0f, 1.5f });

    cldnn::topology topology{
        input_layout("input", input_dyn_layout),
        data("weights", weights_data),
        fully_connected("fc", input_info("input"), "weights")
    };

    ExecutionConfig config = get_test_default_config(engine);
    config.set_property(ov::intel_gpu::optimize_data(true));
    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
    network network(engine, topology, config);
    network.set_input_data("input", input_data);

    auto outputs = network.execute();
    ASSERT_EQ(outputs.size(), size_t(1));
    ASSERT_EQ(outputs.begin()->first, "fc");

    auto inst = network.get_primitive("fc");
    auto impl = inst->get_impl();
    ASSERT_TRUE(impl != nullptr);
    ASSERT_TRUE(impl->is_dynamic());

    ASSERT_TRUE(impl->need_weights_reorder());
    auto weights_reorder_params = impl->get_weights_reorder_params();
    auto out_weights_reorder_layout = weights_reorder_params->get_output_layout();
    auto out_weights_reorder_pshape = out_weights_reorder_layout.get_partial_shape();
    ASSERT_EQ(weights_data->get_layout().get_partial_shape(), out_weights_reorder_pshape);

    auto output_prim_mem = outputs.begin()->second.get_memory();
    cldnn::mem_lock<float> output_ptr(output_prim_mem, get_test_stream());

    ASSERT_EQ(1.5f, output_ptr[0]);
    ASSERT_EQ(0.75f, output_ptr[1]);
    ASSERT_EQ(-2.25f, output_ptr[2]);
    ASSERT_EQ(3.0f, output_ptr[3]);
}
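As a side note, not part of the commit: the expected values asserted above are just the matrix-vector product of the 1x3 input with the 4x3 weights read row-major. A tiny standalone check reproduces them:

#include <array>
#include <cstdio>

int main() {
    const std::array<float, 3>  x = { -0.5f, 2.0f, 0.5f };            // same values as input_data
    const std::array<float, 12> w = { 1.5f,  1.0f,  0.5f,             // same values as weights_data,
                                      -1.0f, 0.0f,  0.5f,             // interpreted as a 4x3 row-major matrix
                                      0.5f, -0.5f, -2.0f,
                                      -0.5f, 1.0f,  1.5f };

    for (int o = 0; o < 4; ++o) {
        float acc = 0.0f;
        for (int i = 0; i < 3; ++i)
            acc += w[o * 3 + i] * x[i];
        std::printf("y[%d] = %g\n", o, acc);  // prints 1.5, 0.75, -2.25, 3
    }
    return 0;
}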
