Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
alvoron committed Aug 7, 2024
1 parent 40d2a14 commit 7e1c28b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ bool ov::intel_cpu::ACLInterpolateExecutor::init(const InterpolateAttrs &interpo
const std::vector <MemoryDescPtr> &dstDescs,
const dnnl::primitive_attr &attr) {
aclInterpolateAttrs = interpolateAttrs;
InterpolateExecutor::init(aclInterpolateAttrs, srcDescs, dstDescs, attr);
acl_coord = arm_compute::SamplingPolicy::TOP_LEFT;
auto& out_shape = dstDescs[0]->getShape().getDims();

if ((aclInterpolateAttrs.coordTransMode == InterpolateCoordTransMode::pytorch_half_pixel && out_shape[2] > 1 && out_shape[3] > 1) ||
if ((aclInterpolateAttrs.coordTransMode == InterpolateCoordTransMode::pytorch_half_pixel && out_shape[index_h] > 1 && out_shape[index_w] > 1) ||
aclInterpolateAttrs.coordTransMode == InterpolateCoordTransMode::half_pixel) {
acl_coord = arm_compute::SamplingPolicy::CENTER;
}
Expand Down Expand Up @@ -97,8 +96,8 @@ bool ov::intel_cpu::ACLInterpolateExecutorBuilder::isSupportedConfiguration(
auto& inp_shape = srcDescs[0]->getShape().getDims();
auto& out_shape = dstDescs[0]->getShape().getDims();

float scale_h = static_cast<float>(out_shape[2]) / inp_shape[2];
float scale_w = static_cast<float>(out_shape[3]) / inp_shape[3];
float scale_h = static_cast<float>(out_shape[index_h]) / inp_shape[index_h];
float scale_w = static_cast<float>(out_shape[index_w]) / inp_shape[index_w];
bool is_upsample = scale_h > 1 && scale_w > 1;

auto& coord_mode = interpolateAttrs.coordTransMode;
Expand Down Expand Up @@ -131,8 +130,8 @@ bool ov::intel_cpu::ACLInterpolateExecutorBuilder::isSupportedConfiguration(
return true;
}
} else if (scale_h < 1 && scale_w < 1) {
float down_scale_h = static_cast<float>(inp_shape[2]) / out_shape[2];
float down_scale_w = static_cast<float>(inp_shape[3]) / out_shape[3];
float down_scale_h = static_cast<float>(inp_shape[index_h]) / out_shape[index_h];
float down_scale_w = static_cast<float>(inp_shape[index_w]) / out_shape[index_w];
bool int_factor = down_scale_h == static_cast<int>(down_scale_h) && down_scale_w == static_cast<int>(down_scale_w);

if (int_factor && coord_mode != InterpolateCoordTransMode::align_corners &&
Expand All @@ -142,7 +141,7 @@ bool ov::intel_cpu::ACLInterpolateExecutorBuilder::isSupportedConfiguration(
}

if (int_factor && nearest_mode == InterpolateNearestMode::round_prefer_ceil &&
((out_shape[2] > 1 && out_shape[3] > 1) || coord_mode != InterpolateCoordTransMode::half_pixel)) {
((out_shape[index_h] > 1 && out_shape[index_w] > 1) || coord_mode != InterpolateCoordTransMode::half_pixel)) {
DEBUG_LOG("!upsample && int_factor && round_prefer_ceil && (out_shape > 1 || half_pixel) case is supported");
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class ACLInterpolateExecutor : public InterpolateExecutor {
}

private:
static const size_t index_h = 2;
static const size_t index_w = 3;

impl_desc_type implType = impl_desc_type::acl;
InterpolateAttrs aclInterpolateAttrs;
arm_compute::SamplingPolicy acl_coord;
Expand Down

0 comments on commit 7e1c28b

Please sign in to comment.