Skip to content

Commit

Permalink
[GPU] Skip primitive if it has only shapeof users (openvinotoolkit#26648)
Browse files Browse the repository at this point in the history

### Details:
- Skip a primitive whose only users are shapeof nodes, unless the primitive
must be executed to determine its output shape (e.g. non-zero).
- Currently, this case has been observed only for permute + shapeof in the Qwen
model.
### Tickets:
 - CVS-146842
  • Loading branch information
yeonbok authored Sep 28, 2024
1 parent df0a12d commit 8f8344a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,47 @@
#include "strided_slice_inst.h"
#include "kv_cache_inst.h"
#include "gemm_inst.h"
#include "shape_of_inst.h"
#include "broadcast_inst.h"
#include "non_zero_inst.h"
#include "non_max_suppression_inst.h"
#include "unique_inst.hpp"
#include "program_helpers.h"

using namespace cldnn;

void mark_runtime_skippable_nodes::run(program& p) {
auto itr = p.get_processing_order().begin();

while (itr != p.get_processing_order().end()) {
auto& node = *itr++;
// Set gathers that might be skipped at runtime as can_be_optimized.
// If not set, memory dependency will not work for the nodes that are skipped at runtime
program_helpers::do_for_types<gather>(*node, [](gather_node& node){
if (node->is_type<data>() || node->is_constant())
continue;

std::function<bool(const program_node& node)> all_users_are_shape_of = [&](const program_node& node) {
if (node.is_input() || node.is_output())
return false;
for (auto& u : node.get_users()) {
if (!u->is_type<shape_of>())
return false;
}
return true;
};

if (all_users_are_shape_of(*node) &&
// primitives that should be executed to know output shapes
!node->is_type<gather_nonzero>() && !node->is_type<unique_gather>() &&
!node->is_type<non_max_suppression_gather>()) {
// always to skip, no runtime execution
node->can_be_optimized(true);
GPU_DEBUG_TRACE_DETAIL << "[mark_runtime_skippable_nodes] : " << node->id() << " has only shape_of as users. Set can_be_optimized always"
<< std::endl;
continue;
}

program_helpers::do_for_types<gather>(*node, [](gather_node& node) {
// Check pattern
auto impl_params = node.get_kernel_impl_params();
if (node.has_fused_primitives() ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void post_input_reorder::run(program& p) {
// add a reorder if primitive's input format doesn't match implementation's input format
if (node->is_type<fully_connected>()) {
const auto fc_impl = dynamic_cast<ocl::typed_primitive_impl_ocl<fully_connected>*>(impl);
if (!fc_impl)
if (!fc_impl || node->can_be_optimized())
continue;
const auto& fc_params =
*static_cast<kernel_selector::fully_connected_params*>(fc_impl->_kernel_data.params.get());
Expand Down

0 comments on commit 8f8344a

Please sign in to comment.