Skip to content

Commit

Permalink
cpu: x64: pool: enable optimized pooling for pad > ur_w
Browse files (browse the repository at this point in the history)
  • Loading branch information
bartekxk committed Oct 19, 2021
1 parent 0ec054a commit 77c71e5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 18 deletions.
50 changes: 32 additions & 18 deletions src/cpu/x64/jit_uni_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,11 +272,6 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
jpp.ur_bc = 1;
jpp.ur_bc_tail = 0;
}
auto ur_w = nstl::min(jpp.ow, jpp.ur / jpp.ur_bc);
if (utils::div_up(jpp.l_pad, jpp.stride_w) > ur_w)
return status::unimplemented;
if (utils::div_up(right_pad, jpp.stride_w) > ur_w)
return status::unimplemented;

// scratchpad for c_block slice of input and/or output
using namespace memory_tracking::names;
Expand Down Expand Up @@ -1301,7 +1296,8 @@ void jit_uni_pool_kernel<isa>::generate() {

auto dt_size = jpp.dt_size;
auto shift = (isa == sse41) ? vlen : 0;
add(reg_input, dt_size * (ur_w * stride_w - lpad) * c_off - shift);
add(reg_input,
dt_size * nstl::max(0, ur_w * stride_w - lpad) * c_off - shift);
add(reg_output, dt_size * ur_w * c_off - shift);
if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) {
auto ishift = (isa == sse41) ? jpp.c_block / 2 : 0;
Expand Down Expand Up @@ -1344,18 +1340,25 @@ void jit_uni_pool_kernel<isa>::generate() {
auto ur_w = nstl::min(jpp.ow, jpp.ur / jpp.ur_bc);
auto ur_w_tail = jpp.ow % ur_w;

int n_oi = ow / ur_w;
const int n_oi_iterations = ow / ur_w;
int n_oi = n_oi_iterations;

int r_pad1
const int r_pad1
= calculate_end_padding(l_pad, ur_w * n_oi, iw, stride_w, kw);
if (r_pad1 > 0) n_oi--;
const int ur_stride_w = ur_w * stride_w;
const int l_pad_iterations = utils::div_up(l_pad, ur_stride_w);
const int r_pad_iterations = utils::div_up(r_pad1, ur_stride_w);

if (l_pad > 0) {
n_oi -= nstl::max(0, r_pad_iterations);

for (int i = 0; i < l_pad_iterations; ++i) {
n_oi--;
const int cur_l_pad = l_pad - i * ur_stride_w;
if (n_oi < 0 && r_pad1 > 0)
process_oi(ur_w, ur_bc, l_pad, r_pad1, with_c_tail_processing);
else
process_oi(ur_w, ur_bc, l_pad, 0, with_c_tail_processing);
process_oi(
ur_w, ur_bc, cur_l_pad, r_pad1, with_c_tail_processing);
else if (n_oi >= 0)
process_oi(ur_w, ur_bc, cur_l_pad, 0, with_c_tail_processing);
}

xor_(oi_iter, oi_iter);
Expand All @@ -1371,12 +1374,23 @@ void jit_uni_pool_kernel<isa>::generate() {
}
}

if (r_pad1 > 0 && n_oi >= 0)
process_oi(ur_w, ur_bc, 0, r_pad1, with_c_tail_processing);
if (n_oi >= 0) {
const int r_pad1_tail = r_pad1 % ur_stride_w != 0
? r_pad1 % ur_stride_w
: ur_stride_w;
for (int i = 0; i < r_pad_iterations; ++i) {
const int cur_r_pad = r_pad1_tail + ur_stride_w * i;
process_oi(ur_w, ur_bc, 0, cur_r_pad, with_c_tail_processing);
}
}

if (ur_w_tail != 0)
process_oi(
ur_w_tail, ur_bc, 0, r_pad, with_c_tail_processing, false);
if (ur_w_tail != 0) {
const int l_pad_tail = n_oi_iterations < l_pad_iterations
? l_pad % ur_stride_w
: 0;
process_oi(ur_w_tail, ur_bc, l_pad_tail, r_pad,
with_c_tail_processing, false);
}
};
Label ur_bc_tail_label, c_tail_processing_label, finish_label;

Expand Down
5 changes: 5 additions & 0 deletions tests/benchdnn/inputs/pool/shapes_2d
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ ic35_iw30ih37_ow14oh17_kw3kh4_sw2sh2
ic35_iw30ih36_ow14oh17_pw0ph1_kw3kh4_sw2sh2
ic35_iw33ih37_ow14oh17_kw6kh4_sw2sh2
ic35_iw33ih36_ow14oh17_pw0ph1_kw6kh4_sw2sh2

# Padding is bigger than ur_w
mb1ic8_ih19oh10kh15dh0sh2ph14_iw19ow10kw15dw0sw2pw14
mb1ic8_ih19oh10kh14dh0sh2ph13_iw19ow10kw14dw0sw2pw13

# With dilation
mb1ic8_ih3oh3_kh3ph1_dh2dw2
mb122ic32_ih32iw2_oh32ow2_kh3kw3_ph1pw1_dh4dw1
Expand Down

0 comments on commit 77c71e5

Please sign in to comment.