Skip to content

Commit

Permalink
src: cpu: dw_conv: width offset correction in depthwise convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
Alok Bakshi committed Aug 15, 2019
1 parent 2fbc8ba commit 6b9d412
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 20 deletions.
39 changes: 29 additions & 10 deletions src/cpu/jit_avx512_core_bf16_dw_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,11 +539,13 @@ inline void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ow_step_unroll(

const int iw_block = ow_block * jcp.stride_w;
const int right_border = jcp.iw - iw_block;
const int r_pad = jcp.r_pad;

const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);

/* preamble count for number of cascaded LOAD + FMA operation */
const int input_overlap = nstl::max(jcp.kw - l_pad, 0);
const bool is_last_block = (unroll_w + ow_block == jcp.ow);

/* LOAD initial input registers, then cascade LOADs and FMAs*/
for (int i_ur = 0; i_ur < unroll_w; ++i_ur) {
Expand All @@ -555,6 +557,13 @@ inline void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ow_step_unroll(
int off_input = (c - pad_offset) * jcp.ch_block;
if (off_input < 0 && unroll_w == jcp.ow)
continue;

const bool over_steps_bdry = true
&& is_last_block
&& (c - pad_offset + r_pad > right_border);
if (over_steps_bdry)
continue;

Zmm zmm_input = get_input_reg(c);
vpmovzxwd(zmm_input,
ptr[reg_tmp_input + off_input * jcp.typesize_in]);
Expand All @@ -565,6 +574,13 @@ inline void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ow_step_unroll(
int off_input = (overlap + c - pad_offset) * jcp.ch_block;
if (off_input < 0 || overlap + c + l_pad > right_border)
continue;

const bool over_steps_bdry = true
&& is_last_block
&& (overlap + c - pad_offset + r_pad > right_border);
if (over_steps_bdry)
continue;

Zmm zmm_input = get_input_reg(overlap + c);
vpmovzxwd(zmm_input,
ptr[reg_tmp_input + off_input * jcp.typesize_in]);
Expand All @@ -579,6 +595,12 @@ inline void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ow_step_unroll(
|| io_overlap - jcp.l_pad >= right_border)
continue;

const bool over_steps_bdry = true
&& is_last_block
&& (io_overlap - jcp.l_pad + jcp.r_pad > right_border);
if (over_steps_bdry)
continue;

Zmm zmm_input = get_input_reg(io_overlap - l_pad);
Zmm zmm_acc = get_acc_reg(i_kw);
if (isa_has_bf16(jcp.isa))
Expand Down Expand Up @@ -811,11 +833,7 @@ inline void jit_avx512_dw_conv_bwd_weights_kernel_bf16::
int ow = jcp.ow;
int pad_offset = 0;
int l_pad = jcp.l_pad;

/* Calculate effective padding */
int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
+ (jcp.kw - 1) * (jcp.dilate_w + 1)
- (jcp.iw + jcp.l_pad - 1));
int r_pad = jcp.r_pad;

/* Is this strictly defined by:
* -code-size (?)
Expand All @@ -826,8 +844,9 @@ inline void jit_avx512_dw_conv_bwd_weights_kernel_bf16::
int unroll_w_tail = 0;
int unroll_w = 0;
int unroll_w_trips = 0;
const bool do_unroll_w = jcp.ow > max_unroll_w;

if (jcp.ow > max_unroll_w) {
if (do_unroll_w) {
unroll_w = nstl::min(block_size, jcp.ow);
unroll_w_trips = ow / unroll_w;
/* calculate tail */
Expand All @@ -845,8 +864,7 @@ inline void jit_avx512_dw_conv_bwd_weights_kernel_bf16::
}
}
} else {
unroll_w = jcp.ow;
unroll_w_trips = nstl::max(1, ow / unroll_w);
unroll_w_tail = jcp.ow;
}
if (jcp.with_bias) {
Label skip_load_bias;
Expand Down Expand Up @@ -876,7 +894,7 @@ inline void jit_avx512_dw_conv_bwd_weights_kernel_bf16::
add(reg_filter_baddr, reg_kh_offset);

/* compute left padded block */
if (l_pad) {
if (l_pad && do_unroll_w) {
compute_h_loop(unroll_w, l_pad, 0, 0);
add(reg_output_baddr, unroll_w * ch_offset * jcp.typesize_in);
add(reg_input_baddr,
Expand Down Expand Up @@ -910,7 +928,8 @@ inline void jit_avx512_dw_conv_bwd_weights_kernel_bf16::

/* compute right padded block */
if (unroll_w_tail) {
compute_h_loop(unroll_w_tail, 0, pad_offset, jcp.ow - unroll_w_tail);
compute_h_loop(unroll_w_tail, l_pad, pad_offset,
jcp.ow - unroll_w_tail);
}
}

Expand Down
40 changes: 30 additions & 10 deletions src/cpu/jit_uni_dw_conv_kernel_f32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,11 +542,13 @@ inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_step_unroll(

const int iw_block = ow_block * jcp.stride_w;
const int right_border = jcp.iw - iw_block;
const int r_pad = jcp.r_pad;

const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);

/* preamble count for number of cascaded LOAD + FMA operation */
const int input_overlap = nstl::max(jcp.kw - l_pad, 0);
const bool is_last_block = (unroll_w + ow_block == jcp.ow);

/* LOAD initial input registers, then cascade LOADs and FMAs*/
for (int r = 0; r < reg_repeats; ++r) {
Expand All @@ -561,6 +563,13 @@ inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_step_unroll(
= ((c - pad_offset) * reg_repeats + r) * simd_w;
if (off_input < 0 && unroll_w == jcp.ow)
continue;

const bool over_steps_bdry = true
&& is_last_block
&& (c - pad_offset + r_pad > right_border);
if (over_steps_bdry)
continue;

Vmm vmm_input
= get_input_reg((c % jcp.kw) * reg_repeats + r);
uni_vmovups(vmm_input,
Expand All @@ -574,6 +583,13 @@ inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_step_unroll(
* simd_w;
if (off_input < 0 || overlap + c + l_pad > right_border)
continue;

const bool over_steps_bdry = true
&& is_last_block
&& (overlap + c - pad_offset + r_pad > right_border);
if (over_steps_bdry)
continue;

Vmm vmm_input = get_input_reg(
((overlap + c) % jcp.kw) * reg_repeats + r);
uni_vmovups(vmm_input,
Expand All @@ -588,6 +604,13 @@ inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_step_unroll(
if (io_overlap - l_pad < 0
|| io_overlap - jcp.l_pad >= right_border)
continue;

const bool over_steps_bdry = true
&& is_last_block
&& (io_overlap - jcp.l_pad + jcp.r_pad > right_border);
if (over_steps_bdry)
continue;

Vmm vmm_input = get_input_reg(
((io_overlap - l_pad) % jcp.kw) * reg_repeats + r);
Vmm vmm_acc = get_acc_reg(i_kw * reg_repeats + r);
Expand Down Expand Up @@ -843,11 +866,7 @@ jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() {
int ow = jcp.ow;
int pad_offset = 0;
int l_pad = jcp.l_pad;

/* Calculate effective padding */
int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
+ (jcp.kw - 1) * (jcp.dilate_w + 1)
- (jcp.iw + jcp.l_pad - 1));
int r_pad = jcp.r_pad;

/* Is this strictly defined by:
* -code-size (?)
Expand All @@ -858,8 +877,9 @@ jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() {
int unroll_w_tail = 0;
int unroll_w = 0;
int unroll_w_trips = 0;
const bool do_unroll_w = jcp.ow > max_unroll_w;

if (jcp.ow > max_unroll_w) {
if (do_unroll_w) {
unroll_w = nstl::min(block_size, jcp.ow);
unroll_w_trips = ow / unroll_w;
/* calculate tail */
Expand All @@ -877,8 +897,7 @@ jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() {
}
}
} else {
unroll_w = jcp.ow;
unroll_w_trips = nstl::max(1, ow / unroll_w);
unroll_w_tail = jcp.ow;
}
if (jcp.with_bias) {
Label skip_load_bias;
Expand Down Expand Up @@ -908,7 +927,7 @@ jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() {
add(reg_filter_baddr, reg_kh_offset);

/* compute left padded block */
if (l_pad) {
if (l_pad && do_unroll_w) {
compute_h_loop(unroll_w, l_pad, 0, 0);
add(reg_output_baddr, unroll_w * ch_offset * sizeof(float));
add(reg_input_baddr,
Expand Down Expand Up @@ -942,7 +961,8 @@ jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() {

/* compute right padded block */
if (unroll_w_tail) {
compute_h_loop(unroll_w_tail, 0, pad_offset, jcp.ow - unroll_w_tail);
compute_h_loop(unroll_w_tail, l_pad, pad_offset,
jcp.ow - unroll_w_tail);
}
}

Expand Down

0 comments on commit 6b9d412

Please sign in to comment.