From 5fda5037935e43d716eb359edfa6125048e386e7 Mon Sep 17 00:00:00 2001
From: alexey-varyzgin
Date: Wed, 27 Jan 2021 19:52:38 +0300
Subject: [PATCH] [BF16] Deconvolution with post ops

---
 src/cpu/x64/gemm_bf16_convolution.cpp         | 46 ++++++++++++++++
 src/cpu/x64/gemm_bf16_convolution.hpp         | 40 +++++++++++++-
 .../jit_avx512_core_bf16_1x1_convolution.cpp  |  3 +
 .../jit_avx512_core_bf16_1x1_convolution.hpp  | 22 +++++++-
 .../x64/jit_avx512_core_bf16_conv_kernel.cpp  | 55 ++++++++++++++++++-
 .../x64/jit_avx512_core_bf16_conv_kernel.hpp  | 35 +++++++++---
 .../x64/jit_avx512_core_bf16_convolution.cpp  |  2 +
 .../x64/jit_avx512_core_bf16_convolution.hpp  |  8 ++-
 .../jit_avx512_core_bf16_dw_conv_kernel.cpp   | 41 ++++++++++++++
 .../jit_avx512_core_bf16_dw_conv_kernel.hpp   | 20 ++++++-
 10 files changed, 253 insertions(+), 19 deletions(-)

diff --git a/src/cpu/x64/gemm_bf16_convolution.cpp b/src/cpu/x64/gemm_bf16_convolution.cpp
index 7cf7394ef55..812558258c4 100644
--- a/src/cpu/x64/gemm_bf16_convolution.cpp
+++ b/src/cpu/x64/gemm_bf16_convolution.cpp
@@ -703,6 +703,8 @@ status_t gemm_bf16_convolution_bwd_data_t<
     // threads share work across mini-batch and groups
     const size_t work_amount = jcp.ngroups * MB;
 
+    const auto& p = pd()->attr()->post_ops_;
+
     acc_data_t *__restrict col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
             + (ptrdiff_t)ithr * jcp.im2col_sz;
     acc_data_t *__restrict acc = scratchpad.get<acc_data_t>(key_conv_gemm_acc)
@@ -735,6 +737,26 @@ status_t gemm_bf16_convolution_bwd_data_t<
         if (jcp.im2col_sz)
             jit_gemm_convolution_utils::col2im_dt<acc_data_t>(jcp, col, acc);
 
+        if (p.len() > 0) {
+            int depthwise_inj_idx = 0;
+            for (int i = 0; i < p.len(); i++) {
+                auto &post_op = p.entry_[i];
+                if (post_op.is_depthwise()) {
+                    auto depthwise_weights = post_op.depthwise.weights_data;
+                    auto depthwise_bias = post_op.depthwise.biases_data;
+                    parallel_nd(static_cast<size_t>(jcp.is) * jcp.id, [&](size_t is) {
+                        diff_src_data_t *__restrict diff_src_arr
+                                = diff_src + is * diff_src_os_stride;
+                        for (int ic = 0; ic < jcp.ic; ic++) {
+                            diff_src_arr[ic] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(diff_src_arr[ic],
+                                    depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic);
+                        }
+                    });
+                    depthwise_inj_idx++;
+                }
+            }
+        }
+
         const bool is_diff_src_bf16 = diff_src_data_type == data_type::bf16;
 
         if (is_diff_src_bf16 && jcp.ngroups == 1 && jcp.nthr != 1) {
@@ -800,6 +822,8 @@ status_t gemm_bf16_convolution_bwd_data_t<diff_src_data_type>::
     const size_t work_amount = (size_t)jcp.ngroups * MB;
     const bool is_problem_3d = pd()->ndims() == 5;
 
+    const auto& p = pd()->attr()->post_ops_;
+
     std::atomic<status_t> st(status::success);
 
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
@@ -853,6 +877,28 @@ status_t gemm_bf16_convolution_bwd_data_t<diff_src_data_type>::
                                 od, os_nb * jcp.os_block, os_block);
                 }
             }
+
+            if (p.len() > 0) {
+                int depthwise_inj_idx = 0;
+                for (int i = 0; i < p.len(); i++) {
+                    auto &post_op = p.entry_[i];
+                    if (post_op.is_depthwise()) {
+                        auto depthwise_weights = post_op.depthwise.weights_data;
+                        auto depthwise_bias = post_op.depthwise.biases_data;
+                        parallel_nd(jcp.ic, [&](const int ic) {
+                            for (int id = 0; id < jcp.id; ++id) {
+                                acc_data_t *d_ = acc + ic * jcp.id * jcp.is + id * jcp.is;
+                                for (int iS = 0; iS < jcp.is; ++iS) {
+                                    d_[iS] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(d_[iS],
+                                            depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic);
+                                }
+                            }
+                        });
+                        depthwise_inj_idx++;
+                    }
+                }
+            }
+
             if (diff_src_data_type == data_type::bf16) {
                 size_t spatial_size = (size_t)jcp.ih * jcp.iw * jcp.id;
                 store_bfloat16_in_parallel((bfloat16_t *)diff_src_local,
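The two hunks above give the reference (GEMM) backward-data path its post-op support: every f32 output element is run through ref_depthwise_scalar_fwd_t::compute_scalar with the per-channel weight and bias selected by g * jcp.ic + ic, before any bf16 down-conversion. A minimal sketch of those per-channel semantics, assuming the scale-shift depthwise algorithm (the injector actually dispatches on post_op.depthwise.alg; the helper below and its buffer layout are illustrative only, not code from this patch):

    #include <cstddef>

    // Illustrative only: per-channel scale-shift over an f32 buffer of shape
    // [C][S] (C channels, S spatial points). 'weights'/'bias' hold one value
    // per channel of the full tensor; 'ch_off' is the channel offset of the
    // current group (g * jcp.ic in the patch).
    static void apply_depthwise_scale_shift(float *data, std::size_t C,
            std::size_t S, const float *weights, const float *bias,
            std::size_t ch_off) {
        for (std::size_t c = 0; c < C; ++c)
            for (std::size_t s = 0; s < S; ++s)
                data[c * S + s] = data[c * S + s] * weights[ch_off + c]
                        + bias[ch_off + c];
    }

One hunk parallelizes over spatial points and walks the channels within each point, the other over channels of the f32 accumulator; both reduce to this elementwise form.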
diff --git a/src/cpu/x64/gemm_bf16_convolution.hpp b/src/cpu/x64/gemm_bf16_convolution.hpp
index 8c2a0c0cd6f..562ae351d1b 100644
--- a/src/cpu/x64/gemm_bf16_convolution.hpp
+++ b/src/cpu/x64/gemm_bf16_convolution.hpp
@@ -28,6 +28,7 @@
 #include "cpu/x64/cpu_reducer.hpp"
 #include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
 #include "cpu/x64/jit_uni_eltwise_injector.hpp"
+#include "cpu/ref_depthwise_injector.hpp"
 
 namespace dnnl {
 namespace impl {
@@ -270,7 +271,7 @@ struct gemm_bf16_convolution_bwd_data_t : public primitive_t {
                     && set_default_alg_kind(alg_kind::convolution_direct)
                     && expect_data_types(diff_src_data_type, data_type::bf16,
                             data_type::undef, data_type::bf16, data_type::f32)
-                    && !has_zero_dim_memory() && attr()->has_default_values();
+                    && !has_zero_dim_memory() && is_supported_post_ops();
             if (!ok) return status::unimplemented;
 
             auto scratchpad = scratchpad_registry().registrar();
@@ -280,9 +281,42 @@ struct gemm_bf16_convolution_bwd_data_t : public primitive_t {
         }
 
         conv_gemm_conf_t jcp_;
+
+
+    protected:
+        virtual bool is_supported_post_ops() const {
+            const auto &p = this->attr()->post_ops_;
+            if (p.len() > 1)
+                return false;
+
+            auto all_post_ops_supported = [&]() {
+                bool ok = true;
+
+                for (int i = 0; i < p.len(); i++) {
+                    ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise);
+                }
+                return ok;
+            };
+
+            return all_post_ops_supported();
+        }
     };
 
-    gemm_bf16_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
+    gemm_bf16_convolution_bwd_data_t(const pd_t* apd) : primitive_t(apd) {
+        const auto& post_ops = pd()->attr()->post_ops_;
+        for (int i = 0; i < post_ops.len(); i++) {
+            auto& post_op = post_ops.entry_[i];
+            if (post_op.is_depthwise()) {
+                depthwise_injectors.push_back(new ref_depthwise_scalar_fwd_t(post_op.depthwise.alg));
+            }
+        }
+    }
+
+    ~gemm_bf16_convolution_bwd_data_t() {
+        for (auto inj : depthwise_injectors)
+            delete inj;
+        depthwise_injectors.clear();
+    }
 
     typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
     typedef typename prec_traits<data_type::f32>::type acc_data_t;
@@ -304,6 +338,8 @@ struct gemm_bf16_convolution_bwd_data_t : public primitive_t {
             const memory_tracking::grantor_t &scratchpad, int MB) const;
 
     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
+
+    nstl::vector<ref_depthwise_scalar_fwd_t*> depthwise_injectors;
 };
 
 template <data_type_t diff_wei_data_type>
diff --git a/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp b/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp
index 65dd02d8b20..fe2a079a762 100644
--- a/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp
+++ b/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp
@@ -574,6 +574,9 @@ void jit_avx512_core_bf16_1x1_convolution_bwd_data_t<
         const size_t str_size = jcp.bcast_dim * max_load_per_thread;
         p.store_buffer = store_buffer + ithr * str_size
                 + data_blk_off(diff_src_d, 0, 0, id, ih, iw);
+
+        p.oc_off = ic_off_idx * (is_dsrc_layout_nxc ? 1 : jcp.ic_block) * sizeof(float);
+
         (*kernel_)(&p);
         if (pd()->rtus_.reduce_src_) (*rtus_driver_)(&rp);
     };
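The new oc_off kernel argument carries the byte offset of the current channel block into the per-channel depthwise weights and biases; its scaling depends on whether diff_src uses an nxc (channels-last) or a blocked layout. A small sketch of that computation under the same assumption, with the helper itself hypothetical:

    #include <cstddef>

    // Hypothetical helper mirroring the offset passed via oc_off/ic_off: in
    // nxc layouts the channel index is already dense, so it is used as-is;
    // in blocked layouts the index counts blocks, so it is scaled by the
    // block size before being turned into a byte offset.
    static std::size_t channel_byte_offset(
            std::size_t ch_idx, bool is_nxc, std::size_t block) {
        return ch_idx * (is_nxc ? std::size_t(1) : block) * sizeof(float);
    }

Inside the JIT kernels that consume it, this runtime offset is added to the weights_data/biases_data base pointers embedded at code-generation time (see the GET_OFF(oc_off) and GET_OFF(ic_off) loads further down), so each call picks up the parameters of the channel block it writes.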
diff --git a/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp b/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp
index 7abe3fedb43..f6cb299776d 100644
--- a/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp
+++ b/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp
@@ -339,8 +339,9 @@ struct jit_avx512_core_bf16_1x1_convolution_bwd_data_t : public primitive_t {
                     && set_default_alg_kind(alg_kind::convolution_direct)
                     && expect_data_types(diff_src_type, data_type::bf16,
                             data_type::undef, data_type::bf16, data_type::undef)
-                    && attr()->has_default_values() && !has_zero_dim_memory()
-                    && set_default_formats();
+                    && !has_zero_dim_memory()
+                    && set_default_formats()
+                    && is_supported_post_ops();
             if (!ok) return status::unimplemented;
 
             const convolution_desc_t *conv_d = desc();
@@ -376,6 +377,23 @@ struct jit_avx512_core_bf16_1x1_convolution_bwd_data_t : public primitive_t {
 
             return set_default_formats_common(dat_tag, wei_tag, dat_tag);
         }
+
+        bool is_supported_post_ops() const {
+            const auto &p = this->attr()->post_ops_;
+            if (p.len() > 1)
+                return false;
+
+            auto all_post_ops_supported = [&]() {
+                bool ok = true;
+
+                for (int i = 0; i < p.len(); i++) {
+                    ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise);
+                }
+                return ok;
+            };
+
+            return all_post_ops_supported();
+        }
     };
 
     template <cpu_isa_t isa, typename conv_t>
diff --git a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp
index 6675637e9bb..c4cd93ec783 100644
--- a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp
+++ b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp
@@ -996,6 +996,29 @@ void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::store_output(int ur_w) {
     if (!isa_has_bf16(jcp.isa)) bf16_emu_->init_vcvtneps2bf16();
     const int ic_tail = jcp.ic_tail;
 
+    int depthwise_inj_idx = 0;
+    const auto& p = attr_.post_ops_;
+    for (int i = 0; i < p.len(); i++) {
+        auto& post_op = p.entry_[i];
+        if (post_op.is_depthwise()) {
+            mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
+            mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
+
+            add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]);
+            add(reg_d_bias, ptr[this->param1 + GET_OFF(oc_off)]);
+
+            for (int k = 0; k < jcp.nb_ic_blocking; k++) {
+                depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
+                        k * jcp.ur_w, k * jcp.ur_w + ur_w, reg_d_weights, reg_d_bias);
+
+                add(reg_d_weights, jcp.ic_block * sizeof(float));
+                add(reg_d_bias, jcp.ic_block * sizeof(float));
+            }
+
+            depthwise_inj_idx++;
+        }
+    }
+
     if (jcp.dst_dt == data_type::f32) {
         for (int k = 0; k < jcp.nb_ic_blocking; k++)
             for (int j = 0; j < ur_w; j++) {
@@ -1238,6 +1261,17 @@ void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::compute_loop(
 
 template <typename Vmm>
 void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::generate() {
+    const auto &p = attr_.post_ops_;
+    for (int i = 0; i < p.len(); i++) {
+        auto &post_op = p.entry_[i];
+        if (post_op.is_depthwise()) {
+            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
+                this,
+                post_op.depthwise.alg
+            ));
+        }
+    }
+
     int iw = jcp.iw;
     int kw = jcp.kw;
     int ur_w = jcp.ur_w;
@@ -1424,9 +1458,26 @@ void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::generate() {
     postamble();
 }
 
+bool jit_avx512_core_bf16_bwd_data_kernel::post_ops_ok(
+        jit_conv_conf_t& jcp, const primitive_attr_t& attr) {
+    const auto& p = attr.post_ops_;
+
+    auto all_post_ops_supported = [&]() {
+        bool ok = true;
+
+        for (int i = 0; i < p.len(); i++) {
+            ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise);
+        }
+        return ok;
+    };
+
+    return all_post_ops_supported();
+}
+
 status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(jit_conv_conf_t &jcp,
         const convolution_desc_t &cd, memory_desc_t &diff_src_md,
-        memory_desc_t &weights_md, memory_desc_t &diff_dst_md, int nthreads) {
+        memory_desc_t &weights_md, memory_desc_t &diff_dst_md,
+        const primitive_attr_t& attr, int nthreads) {
 
     const memory_desc_wrapper diff_src_d(&diff_src_md);
     const memory_desc_wrapper weights_d(&weights_md);
@@ -1582,6 +1633,8 @@ status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(jit_conv_conf_t &jcp,
             && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
     if (!args_ok) return status::unimplemented;
 
+    if (!post_ops_ok(jcp, attr)) return status::unimplemented;
+
     jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
     jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
diff --git a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp
index a71a4e9566c..54b77c74425 100644
--- a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp
+++ b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp
@@ -275,19 +275,30 @@ struct jit_avx512_core_bf16_fwd_kernel {
 template <typename Vmm>
 struct _jit_avx512_core_bf16_bwd_data_kernel : public jit_generator {
 
-    _jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp)
-        : jit_generator(nullptr, ker_code_size), jcp(ajcp), bf16_emu_(nullptr) {
+    _jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp,
+            const primitive_attr_t& attr)
+        : jit_generator(nullptr, ker_code_size)
+        , jcp(ajcp)
+        , attr_(attr)
+        , bf16_emu_(nullptr) {
         if (!isa_has_bf16(jcp.isa))
             bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserv_1,
                     bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_scratch,
                     bf16_emu_reserv_4, bf16_emu_reserv_5);
     }
 
-    ~_jit_avx512_core_bf16_bwd_data_kernel() { delete bf16_emu_; }
+    ~_jit_avx512_core_bf16_bwd_data_kernel() {
+        for (auto inj : depthwise_injectors)
+            delete inj;
+        depthwise_injectors.clear();
+
+        delete bf16_emu_;
+    }
 
     DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_bf16_bwd_data_kernel_f32)
 
     const jit_conv_conf_t &jcp;
+    const primitive_attr_t& attr_;
 
 private:
     using Vmm_down_t =
@@ -364,6 +375,12 @@ struct _jit_avx512_core_bf16_bwd_data_kernel : public jit_generator {
     Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(30);
 
     Vmm vmm_wei = Vmm(31);
+
+    reg64_t reg_d_weights = r15;
+    reg64_t reg_d_bias = reg_kj;
+
+    nstl::vector<jit_uni_depthwise_injector_f32<avx512_common>*> depthwise_injectors;
+
     bf16_emulation_t *bf16_emu_;
 
     inline void prepare_output(int ur_w);
@@ -445,20 +462,20 @@ struct _jit_avx512_core_bf16_bwd_data_kernel : public jit_generator {
 
 struct jit_avx512_core_bf16_bwd_data_kernel {
 
-    jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp)
+    jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp, const primitive_attr_t& attr)
         : kernel_(nullptr) {
         switch (ajcp.ic_block) {
             case 16:
                 kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Zmm>(
-                        ajcp);
+                        ajcp, attr);
                 return;
            case 8:
                 kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Ymm>(
-                        ajcp);
+                        ajcp, attr);
                 return;
            case 4:
                 kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Xmm>(
-                        ajcp);
+                        ajcp, attr);
                 return;
            default: assert(!"invalid channel blocking");
         }
@@ -468,10 +485,12 @@ struct jit_avx512_core_bf16_bwd_data_kernel {
 
     ~jit_avx512_core_bf16_bwd_data_kernel() { delete kernel_; }
 
+    static bool post_ops_ok(jit_conv_conf_t& jcp, const primitive_attr_t& attr);
+
     static status_t init_conf(jit_conv_conf_t &jcp,
             const convolution_desc_t &cd, memory_desc_t &diff_src_md,
             memory_desc_t &weights_md, memory_desc_t &diff_dst_md,
-            int nthreads);
+            const primitive_attr_t& attr, int nthreads);
     void operator()(const jit_conv_call_s *p) const { (*kernel_)(p); }
     const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); }
diff --git a/src/cpu/x64/jit_avx512_core_bf16_convolution.cpp b/src/cpu/x64/jit_avx512_core_bf16_convolution.cpp
index 9629406f87c..8e8d70b51a7 100644
--- a/src/cpu/x64/jit_avx512_core_bf16_convolution.cpp
+++ b/src/cpu/x64/jit_avx512_core_bf16_convolution.cpp
@@ -554,6 +554,7 @@ void jit_avx512_core_bf16_convolution_bwd_data_t ::execute_backward_data_3d(
             par_conv.filt = wht_w + kh_lo * wht_h_stride;
             par_conv.kh_padding = kh_len;
             par_conv.kd_padding = kd_len;
+            par_conv.oc_off = ic_idx * (is_dsrc_layout_nxc ? 1 : jcp.ic_block) * sizeof(float);
 
             (*kernel_)(&par_conv);
         }
@@ -693,6 +694,7 @@ void jit_avx512_core_bf16_convolution_bwd_data_t ::execute_backward_data(
                 par_conv.filt = wht_w + k_lo * wht_h_stride;
                 par_conv.kh_padding = k_len;
                 par_conv.iwb = iwb;
+                par_conv.oc_off = ic_idx * (is_dsrc_layout_nxc ? 1 : jcp.ic_block) * sizeof(float);
 
                 (*kernel_)(&par_conv);
             }
diff --git a/src/cpu/x64/jit_avx512_core_bf16_convolution.hpp b/src/cpu/x64/jit_avx512_core_bf16_convolution.hpp
index c95297c40cc..fdfd82f0007 100644
--- a/src/cpu/x64/jit_avx512_core_bf16_convolution.hpp
+++ b/src/cpu/x64/jit_avx512_core_bf16_convolution.hpp
@@ -135,11 +135,12 @@ struct jit_avx512_core_bf16_convolution_bwd_data_t : public primitive_t {
                             || expect_data_types(data_type::bf16,
                                     data_type::bf16, data_type::undef,
                                     data_type::bf16, data_type::undef))
-                    && attr()->has_default_values() && !has_zero_dim_memory();
+                    && attr()->has_default_values(primitive_attr_t::skip_mask_t::post_ops)
+                    && !has_zero_dim_memory();
             if (!ok) return status::unimplemented;
 
             status_t status = jit_avx512_core_bf16_bwd_data_kernel::init_conf(
-                    jcp_, *desc(), diff_src_md_, weights_md_, diff_dst_md_,
+                    jcp_, *desc(), diff_src_md_, weights_md_, diff_dst_md_, *attr(),
                     dnnl_get_max_threads());
             return status;
         }
@@ -155,7 +156,8 @@ struct jit_avx512_core_bf16_convolution_bwd_data_t : public primitive_t {
 
     status_t init(engine_t *engine) override {
         CHECK(safe_ptr_assign(
-                kernel_, new jit_avx512_core_bf16_bwd_data_kernel(pd()->jcp_)));
+                kernel_, new jit_avx512_core_bf16_bwd_data_kernel(
+                        pd()->jcp_, *pd()->attr())));
         return kernel_->create_kernel();
     }
 
diff --git a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp
index ea7c9d7edc1..07406a18464 100644
--- a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp
+++ b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp
@@ -558,6 +558,34 @@ inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::apply_filter(
     L(iter_exit_label);
 }
 
+void jit_avx512_dw_conv_bwd_data_kernel_bf16::apply_postprocess(int ur_ch_blocks, int ur_str_) {
+    const auto& p = attr_.post_ops_;
+    int depthwise_inj_idx = 0;
+    for (int i = 0; i < p.len(); i++) {
+        auto& post_op = p.entry_[i];
+        if (post_op.is_depthwise()) {
+            mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
+            mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
+
+            add(reg_d_weights, ptr[this->param1 + GET_OFF(ic_off)]);
+            add(reg_d_bias, ptr[this->param1 + GET_OFF(ic_off)]);
+
+            for (int ch = 0; ch < ur_ch_blocks; ch++) {
+                int start_idx = get_acc_reg(ur_str_ * ch).getIdx();
+                int end_idx = get_acc_reg(ur_str_ * ch + ur_str_).getIdx();
+
+                depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
+                        start_idx, end_idx, reg_d_weights, reg_d_bias);
+
+                add(reg_d_weights, jcp.ch_block * sizeof(float));
+                add(reg_d_bias, jcp.ch_block * sizeof(float));
+            }
+
+            depthwise_inj_idx++;
+        }
+    }
+}
+
 inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::store_dsrc(
         int ur_ch_blocks, int ur_str_w) {
     int ch_blk = jcp.ch_block;
@@ -610,6 +638,7 @@ inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::loop_body(
 
         load_ddst(ur_ch_blocks, ur_w);
         apply_filter(ur_ch_blocks, ur_w);
+        apply_postprocess(ur_ch_blocks, ur_w);
         store_dsrc(ur_ch_blocks, ur_w);
 
         add(reg_dsrc, jcp.typesize_out * ur_w * jcp.ch_block * jcp.stride_w);
@@ -631,6 +660,7 @@ inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::loop_body(
 
         load_ddst(ur_ch_blocks, ur_w);
        apply_filter(ur_ch_blocks, ur_w);
+        apply_postprocess(ur_ch_blocks, ur_w);
         store_dsrc(ur_ch_blocks, ur_w);
 
         add(reg_dsrc, jcp.typesize_out * ur_w * jcp.ch_block * jcp.stride_w);
@@ -644,6 +674,17 @@
 }
 
 void jit_avx512_dw_conv_bwd_data_kernel_bf16::generate() {
+    const auto& p = attr_.post_ops_;
+    for (int i = 0; i < p.len(); i++) {
+        auto& post_op = p.entry_[i];
+        if (post_op.is_depthwise()) {
+            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
+                this,
+                post_op.depthwise.alg
+            ));
+        }
+    }
+
     preamble();
     mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
     mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
diff --git a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp
index 6c930bc0510..77465780dbd 100644
--- a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp
+++ b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp
@@ -145,8 +145,8 @@ struct jit_avx512_dw_conv_fwd_kernel_bf16 : public jit_generator {
 struct jit_avx512_dw_conv_bwd_data_kernel_bf16 : public jit_generator {
     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_dw_conv_bwd_data_kernel_bf16)
 
-    jit_avx512_dw_conv_bwd_data_kernel_bf16(const jit_conv_conf_t &ajcp, const primitive_attr_t&)
-        : jcp(ajcp), bf16_emu_(nullptr) {
+    jit_avx512_dw_conv_bwd_data_kernel_bf16(const jit_conv_conf_t &ajcp, const primitive_attr_t& attr)
+        : jcp(ajcp), attr_(attr), bf16_emu_(nullptr) {
 
         if (!isa_has_bf16(jcp.isa))
             bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserv_1,
@@ -154,10 +154,18 @@ struct jit_avx512_dw_conv_bwd_data_kernel_bf16 : public jit_generator {
                     bf16_emu_reserv_5, bf16_emu_reserv_6);
     }
 
-    ~jit_avx512_dw_conv_bwd_data_kernel_bf16() { delete bf16_emu_; }
+    ~jit_avx512_dw_conv_bwd_data_kernel_bf16() {
+        for (auto inj : depthwise_injectors)
+            delete inj;
+        depthwise_injectors.clear();
+
+        delete bf16_emu_;
+    }
 
     jit_conv_conf_t jcp;
+    const primitive_attr_t& attr_;
+
 private:
     using reg64_t = const Xbyak::Reg64;
@@ -188,6 +196,9 @@ struct jit_avx512_dw_conv_bwd_data_kernel_bf16 : public jit_generator {
     reg64_t reg_kh = r13;
     reg64_t reg_kw = r14;
 
+    reg64_t reg_d_weights = r15;
+    reg64_t reg_d_bias = iter_kh;
+
     Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26);
     Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27);
     Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28);
@@ -200,9 +211,12 @@ struct jit_avx512_dw_conv_bwd_data_kernel_bf16 : public jit_generator {
     inline void loop_body(int ur_ch_blocks);
     inline void load_ddst(int ur_ch_blocks, int ur_str_w);
     inline void apply_filter(int ur_ch_blocks, int ur_str_w);
+    inline void apply_postprocess(int ur_ch_blocks, int ur_str_w);
     inline void store_dsrc(int ur_ch_blocks, int ur_str_w);
 
     void generate() override;
+
+    nstl::vector<jit_uni_depthwise_injector_f32<avx512_common>*> depthwise_injectors;
 };
 
 struct jit_avx512_dw_conv_bwd_weights_kernel_bf16 : public jit_generator {
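Across the implementations, the GEMM and 1x1 descriptors accept at most one post-op and require every post-op to be depthwise, while the direct JIT kernel enforces the depthwise-only part in post_ops_ok() before init_conf() succeeds. A self-contained sketch of the stricter check, using hypothetical stand-in types (the real code walks primitive_attr_t::post_ops_ and compares entry kinds against primitive_kind::depthwise):

    #include <vector>

    // Hypothetical stand-in for the post-op kinds inspected by the checks
    // added in this patch.
    enum class post_op_kind { depthwise, eltwise, sum };

    // Accept only attributes that carry at most one post-op, and only if
    // that post-op is depthwise; this mirrors is_supported_post_ops() above.
    static bool post_ops_supported(const std::vector<post_op_kind> &ops) {
        if (ops.size() > 1) return false;
        for (auto kind : ops)
            if (kind != post_op_kind::depthwise) return false;
        return true;
    }

In the depthwise convolution kernel the resulting apply_postprocess() step runs between apply_filter() and store_dsrc(), so the post-op is applied to the f32 accumulators before they are converted and stored.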