[GPU] Update weights reorder output shape for fully_connected
Lyamin-Roman committed Sep 18, 2023
1 parent 10dc2d8 commit 5e0a65c
Showing 2 changed files with 25 additions and 14 deletions.
Expand Up @@ -84,20 +84,6 @@ void post_optimize_weights::optimize_weights(T& node, program& p) {
!prev_node.has_fused_primitives() &&
!prev_node.as<reorder>().has_mean() &&
prev_node.as<reorder>().get_primitive()->subtract_per_feature.empty();
if (impl->is_dynamic()) {
if (weights_reorder_params->get_output_layout().compatible(prev_node.get_output_layout())) {
// if compatible, it can be reinterpreted, thus no need to reorder at build time
continue;
}
// Need to restore the original shape
auto updated_output_layout = weights_reorder_params->get_output_layout();
auto orig_rank = prev_node.get_output_layout().get_partial_shape().size();
auto weight_format_dims = format::dimension(weights_reorder_params->get_output_layout().format);
updated_output_layout.set_partial_shape(
updated_output_layout.get_tensor().get_partial_shape(orig_rank, weight_format_dims));
if (updated_output_layout != weights_reorder_params->get_output_layout())
weights_reorder_params->set_output_layout(updated_output_layout);
}
if (can_be_fused) {
// Need to update input data_type for correct merging format reorder with precision reorder
data_types input_dtype = prev_node.get_input_layouts()[0].data_type;
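The block removed above restored the weights reorder output shape to the producer's original rank for dynamic impls before the reorder was applied. The following is a minimal sketch of that intent, not part of the commit: it uses only the public ov::PartialShape type, and the restore_original_rank helper and the example shape are hypothetical stand-ins for the cldnn layout/tensor calls in the removed code.

```cpp
#include <openvino/core/partial_shape.hpp>

#include <cstddef>
#include <iostream>

// Hypothetical helper: keep only the first `orig_rank` dimensions of a shape that was
// padded out to the weight format's rank (the trailing dimensions are padding of size 1).
static ov::PartialShape restore_original_rank(const ov::PartialShape& padded, std::size_t orig_rank) {
    ov::PartialShape restored;
    for (std::size_t i = 0; i < orig_rank && i < padded.size(); ++i)
        restored.push_back(padded[i]);
    return restored;
}

int main() {
    // Example only: a 2D fully_connected weight [256, 768] padded to a 4D weight format shape.
    const ov::PartialShape padded{256, 768, 1, 1};
    std::cout << restore_original_rank(padded, 2) << std::endl;  // prints the original 2D shape
    return 0;
}
```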
src/plugins/intel_gpu/src/graph/impls/ocl/fully_connected.cpp (25 additions, 0 deletions)
@@ -19,6 +19,31 @@ struct fully_connected_impl : typed_primitive_impl_ocl<fully_connected> {
 
     DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::fully_connected_impl)
 
+    fully_connected_impl() = default;
+
+    fully_connected_impl(const kernel_selector::kernel_data& kd) {
+        const auto& params = kd.weightsReorderParams;
+
+        if (params.is_initialized) {
+            // Assumption that kernel data contains already reshaped 2d weights
+            auto crop_to_2d = [](const ov::PartialShape& shape) {
+                return ov::PartialShape({shape[0], shape[1]});
+            };
+
+            auto weights_reorder_params = std::make_shared<WeightsReorderParams>(from_weights_tensor(params.src),
+                                                                                 from_weights_tensor(params.dest),
+                                                                                 params.rotate);
+            auto output_layout = weights_reorder_params->get_output_layout();
+            output_layout.set_partial_shape(crop_to_2d(output_layout.get_partial_shape()));
+            weights_reorder_params->set_output_layout(output_layout);
+
+            _weights_reorder_params = weights_reorder_params;
+        }
+        _kernel_data = kd;
+        _kernel_name = kd.kernelName;
+        can_reuse_memory = _kernel_data.can_reuse_memory;
+    }
+
     std::unique_ptr<primitive_impl> clone() const override {
         return make_unique<fully_connected_impl>(*this);
     }
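The crop_to_2d fixup added above can be exercised in isolation. A minimal sketch, not part of the commit, assuming only OpenVINO's public ov::PartialShape: the lambda is copied from the new constructor, which relies on the kernel data already describing the weights as 2D, while the example input shape is made up.

```cpp
#include <openvino/core/partial_shape.hpp>

#include <iostream>

int main() {
    // Same lambda as in fully_connected_impl: keep only the first two dimensions.
    auto crop_to_2d = [](const ov::PartialShape& shape) {
        return ov::PartialShape({shape[0], shape[1]});
    };

    // Example only: a weights reorder output reported with a padded 4D shape.
    const ov::PartialShape reorder_output{64, 128, 1, 1};
    std::cout << crop_to_2d(reorder_output) << std::endl;  // prints the 2D shape fully_connected expects
    return 0;
}
```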
