Skip to content

Commit

Permalink
Skip concat_input_order for an element which lies outside of the buff…
Browse files Browse the repository at this point in the history
…er range (#23573)
  • Loading branch information
steve-y committed Mar 21, 2024
1 parent afd444a commit a208631
Showing 1 changed file with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ namespace {

using shuffle_range = std::pair<int32_t, int32_t>;

bool can_shuffle_features(program_node& node, stream& stream) {
bool can_shuffle_features(program_node& node, program_node& concat_node, stream& stream) {
if (node.is_type<convolution>()) {
auto& conv_node = node.as<convolution>();
auto& wei_node = conv_node.weights();
if (ov::element::Type(wei_node.get_output_layout().data_type).bitwidth() < 8)
return false;

return conv_node.get_groups() == 1 &&
return conv_node.get_groups() == 1 && node.get_dependency_index(concat_node) == 0 &&
conv_node.get_deformable_groups() == 1 && !conv_node.get_transposed() &&
!conv_node.activations_zero_points_term() &&
wei_node.is_type<data>() && wei_node.is_constant() && !wei_node.is_output();
Expand All @@ -37,7 +37,7 @@ bool can_shuffle_features(program_node& node, stream& stream) {
if (ov::element::Type(wei_node.get_output_layout().data_type).bitwidth() < 8)
return false;

return wei_node.is_type<data>() && wei_node.is_constant() && !wei_node.is_output();
return node.get_dependency_index(concat_node) == 0 && wei_node.is_type<data>() && wei_node.is_constant() && !wei_node.is_output();
}

bool pass_through = false;
Expand All @@ -48,7 +48,7 @@ bool can_shuffle_features(program_node& node, stream& stream) {
if (pass_through) {
// Primitives that are feature order invariant, pass-through shuffled features to users
for (auto& user : node.get_users()) {
if (!can_shuffle_features(*user, stream))
if (!can_shuffle_features(*user, concat_node, stream))
return false;
}
return true;
Expand Down Expand Up @@ -160,7 +160,7 @@ void concat_input_order::run(program& p) {
// Check that we can fuse shuffling to users
bool can_shuffle_users = true;
for (auto user : concat_node.get_users()) {
can_shuffle_users &= can_shuffle_features(*user, p.get_stream());
can_shuffle_users &= can_shuffle_features(*user, concat_node, p.get_stream());
}

if (!along_f || !no_fusing || !correct_format || !single_format || already_aligned || !can_shuffle_users)
Expand Down

0 comments on commit a208631

Please sign in to comment.