Skip to content

Commit

Permalink
fix nhwc indices
Browse files Browse the repository at this point in the history
  • Loading branch information
alvoron committed Aug 6, 2024
1 parent c664ca7 commit 6afe7a3
Showing 1 changed file with 6 additions and 27 deletions.
33 changes: 6 additions & 27 deletions src/plugins/intel_cpu/src/nodes/executors/acl/acl_interpolate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,6 @@
#include "acl_utils.hpp"
#include "utils/debug_capabilities.h"

static bool getIndices(const ov::intel_cpu::MemoryDescPtr &desc, int& index_h, int& index_w) {
if (desc->hasLayoutType(ov::intel_cpu::LayoutType::ncsp)) {
index_h = 2;
index_w = 3;
return true;
} else if (desc->hasLayoutType(ov::intel_cpu::LayoutType::nspc)) {
index_h = 1;
index_w = 2;
return true;
} else { return false; }
}

bool ov::intel_cpu::ACLInterpolateExecutor::init(const InterpolateAttrs &interpolateAttrs,
const std::vector <MemoryDescPtr> &srcDescs,
const std::vector <MemoryDescPtr> &dstDescs,
Expand All @@ -27,13 +15,7 @@ bool ov::intel_cpu::ACLInterpolateExecutor::init(const InterpolateAttrs &interpo
acl_coord = arm_compute::SamplingPolicy::TOP_LEFT;
auto& out_shape = dstDescs[0]->getShape().getDims();

int index_h, index_w;
if (!getIndices(dstDescs[0], index_h, index_w)) {
DEBUG_LOG("ACL Interpolate unsupported layout: ", dstDescs[0]->serializeFormat());
return false;
}

if ((aclInterpolateAttrs.coordTransMode == InterpolateCoordTransMode::pytorch_half_pixel && out_shape[index_h] > 1 && out_shape[index_w] > 1) ||
if ((aclInterpolateAttrs.coordTransMode == InterpolateCoordTransMode::pytorch_half_pixel && out_shape[2] > 1 && out_shape[3] > 1) ||
aclInterpolateAttrs.coordTransMode == InterpolateCoordTransMode::half_pixel) {
acl_coord = arm_compute::SamplingPolicy::CENTER;
}
Expand Down Expand Up @@ -115,11 +97,8 @@ bool ov::intel_cpu::ACLInterpolateExecutorBuilder::isSupportedConfiguration(
auto& inp_shape = srcDescs[0]->getShape().getDims();
auto& out_shape = dstDescs[0]->getShape().getDims();

int index_h, index_w;
if (!getIndices(srcDescs[0], index_h, index_w)) { return false; }

float scale_h = static_cast<float>(out_shape[index_h]) / inp_shape[index_h];
float scale_w = static_cast<float>(out_shape[index_w]) / inp_shape[index_w];
float scale_h = static_cast<float>(out_shape[2]) / inp_shape[2];
float scale_w = static_cast<float>(out_shape[3]) / inp_shape[3];
bool is_upsample = scale_h > 1 && scale_w > 1;

auto& coord_mode = interpolateAttrs.coordTransMode;
Expand Down Expand Up @@ -152,8 +131,8 @@ bool ov::intel_cpu::ACLInterpolateExecutorBuilder::isSupportedConfiguration(
return true;
}
} else if (scale_h < 1 && scale_w < 1) {
float down_scale_h = static_cast<float>(inp_shape[index_h]) / out_shape[index_h];
float down_scale_w = static_cast<float>(inp_shape[index_w]) / out_shape[index_w];
float down_scale_h = static_cast<float>(inp_shape[2]) / out_shape[2];
float down_scale_w = static_cast<float>(inp_shape[3]) / out_shape[3];
bool int_factor = down_scale_h == static_cast<int>(down_scale_h) && down_scale_w == static_cast<int>(down_scale_w);

if (int_factor && coord_mode != InterpolateCoordTransMode::align_corners &&
Expand All @@ -163,7 +142,7 @@ bool ov::intel_cpu::ACLInterpolateExecutorBuilder::isSupportedConfiguration(
}

if (int_factor && nearest_mode == InterpolateNearestMode::round_prefer_ceil &&
((out_shape[index_h] > 1 && out_shape[index_w] > 1) || coord_mode != InterpolateCoordTransMode::half_pixel)) {
((out_shape[2] > 1 && out_shape[3] > 1) || coord_mode != InterpolateCoordTransMode::half_pixel)) {
DEBUG_LOG("!upsample && int_factor && round_prefer_ceil && (out_shape > 1 || half_pixel) case is supported");
return true;
}
Expand Down

0 comments on commit 6afe7a3

Please sign in to comment.