[BF16] Deconvolution with post ops #31

Merged 1 commit on Feb 4, 2021
46 changes: 46 additions & 0 deletions src/cpu/x64/gemm_bf16_convolution.cpp
@@ -703,6 +703,8 @@ status_t gemm_bf16_convolution_bwd_data_t<
// threads share work across mini-batch and groups
const size_t work_amount = jcp.ngroups * MB;

const auto& p = pd()->attr()->post_ops_;

acc_data_t *__restrict col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
+ (ptrdiff_t)ithr * jcp.im2col_sz;
acc_data_t *__restrict acc = scratchpad.get<acc_data_t>(key_conv_gemm_acc)
@@ -735,6 +737,26 @@ status_t gemm_bf16_convolution_bwd_data_t<
if (jcp.im2col_sz)
jit_gemm_convolution_utils::col2im_dt<acc_data_t>(jcp, col, acc);

if (p.len() > 0) {
int depthwise_inj_idx = 0;
for (int i = 0; i < p.len(); i++) {
auto &post_op = p.entry_[i];
if (post_op.is_depthwise()) {
auto depthwise_weights = post_op.depthwise.weights_data;
auto depthwise_bias = post_op.depthwise.biases_data;
parallel_nd(static_cast<size_t>(jcp.is) * jcp.id, [&](size_t is) {
diff_src_data_t*__restrict diff_src_arr
= diff_src + is * diff_src_os_stride;
for (int ic = 0; ic < jcp.ic; ic++) {
diff_src_arr[ic] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(diff_src_arr[ic],
depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic);
}
});
depthwise_inj_idx++;
}
}
}

const bool is_diff_src_bf16 = diff_src_data_type == data_type::bf16;

if (is_diff_src_bf16 && jcp.ngroups == 1 && jcp.nthr != 1) {
@@ -800,6 +822,8 @@ status_t gemm_bf16_convolution_bwd_data_t<diff_src_data_type>::
const size_t work_amount = (size_t)jcp.ngroups * MB;
const bool is_problem_3d = pd()->ndims() == 5;

const auto& p = pd()->attr()->post_ops_;

std::atomic<status_t> st(status::success);

parallel(jcp.nthr, [&](const int ithr, const int nthr) {
@@ -853,6 +877,28 @@ status_t gemm_bf16_convolution_bwd_data_t<diff_src_data_type>::
od, os_nb * jcp.os_block, os_block);
}
}

if (p.len() > 0) {
int depthwise_inj_idx = 0;
for (int i = 0; i < p.len(); i++) {
auto &post_op = p.entry_[i];
if (post_op.is_depthwise()) {
auto depthwise_weights = post_op.depthwise.weights_data;
auto depthwise_bias = post_op.depthwise.biases_data;
parallel_nd(jcp.ic, [&](const int ic) {
for (int id = 0; id < jcp.id; ++id) {
acc_data_t *d_ = acc + ic * jcp.id * jcp.is + id * jcp.is;
for (int iS = 0; iS < jcp.is; ++iS) {
d_[iS] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(d_[iS],
depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic);
}
}
});
depthwise_inj_idx++;
}
}
}

if (diff_src_data_type == data_type::bf16) {
size_t spatial_size = (size_t)jcp.ih * jcp.iw * jcp.id;
store_bfloat16_in_parallel((bfloat16_t *)diff_src_local,
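Note: the two hunks above apply the depthwise post-op on the reference GEMM path, walking the plane in whichever order matches the diff_src layout (spatial-outer with a channel stride in the first hunk, channel-outer over contiguous spatial in the second). As a minimal sketch of the per-channel math, assuming the scale_shift depthwise algorithm (names here are illustrative, not the oneDNN internals):

```cpp
#include <cstddef>

// For a [C x S] plane (C channels, S spatial points):
// data[c][s] = data[c][s] * weights[c] + bias[c]
static void apply_depthwise_scale_shift(float *data, const float *weights,
        const float *bias, std::size_t C, std::size_t S) {
    for (std::size_t c = 0; c < C; ++c)
        for (std::size_t s = 0; s < S; ++s)
            data[c * S + s] = data[c * S + s] * weights[c] + bias[c];
}
```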
40 changes: 38 additions & 2 deletions src/cpu/x64/gemm_bf16_convolution.hpp
@@ -28,6 +28,7 @@
#include "cpu/x64/cpu_reducer.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_uni_eltwise_injector.hpp"
#include "cpu/ref_depthwise_injector.hpp"

namespace dnnl {
namespace impl {
@@ -270,7 +271,7 @@ struct gemm_bf16_convolution_bwd_data_t : public primitive_t {
&& set_default_alg_kind(alg_kind::convolution_direct)
&& expect_data_types(diff_src_data_type, data_type::bf16,
data_type::undef, data_type::bf16, data_type::f32)
&& !has_zero_dim_memory() && attr()->has_default_values();
&& !has_zero_dim_memory() && is_supported_post_ops();
if (!ok) return status::unimplemented;

auto scratchpad = scratchpad_registry().registrar();
@@ -280,9 +281,42 @@ struct gemm_bf16_convolution_bwd_data_t : public primitive_t {
}

conv_gemm_conf_t jcp_;

protected:
virtual bool is_supported_post_ops() const {
const auto &p = this->attr()->post_ops_;
if (p.len() > 1)
return false;

auto all_post_ops_supported = [&]() {
bool ok = true;

for (int i = 0; i < p.len(); i++) {
ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise);
}
return ok;
};

return all_post_ops_supported();
}
};

gemm_bf16_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
gemm_bf16_convolution_bwd_data_t(const pd_t* apd) : primitive_t(apd) {
const auto& post_ops = pd()->attr()->post_ops_;
for (int i = 0; i < post_ops.len(); i++) {
auto& post_op = post_ops.entry_[i];
if (post_op.is_depthwise()) {
depthwise_injectors.push_back(new ref_depthwise_scalar_fwd_t(post_op.depthwise.alg));
}
}
}

~gemm_bf16_convolution_bwd_data_t() {
for (auto inj : depthwise_injectors)
delete inj;
depthwise_injectors.clear();
}

typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
typedef typename prec_traits<data_type::f32>::type acc_data_t;
@@ -304,6 +338,8 @@ struct gemm_bf16_convolution_bwd_data_t : public primitive_t {
const memory_tracking::grantor_t &scratchpad, int MB) const;

const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

nstl::vector<ref_depthwise_scalar_fwd_t*> depthwise_injectors;
};

template <data_type_t diff_wei_data_type>
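Note: the is_supported_post_ops() gate added to pd_t accepts at most one post-op, and only of depthwise kind. A standalone restatement of the predicate, with an illustrative enum standing in for primitive_kind:

```cpp
#include <vector>

enum class kind { depthwise, eltwise, sum, binary };

// Reject chains longer than one entry; require every entry (i.e., the
// single permitted one) to be a depthwise post-op.
static bool is_supported_post_ops(const std::vector<kind> &post_ops) {
    if (post_ops.size() > 1) return false;
    for (kind k : post_ops)
        if (k != kind::depthwise) return false;
    return true;
}
```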
3 changes: 3 additions & 0 deletions src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp
@@ -574,6 +574,9 @@ void jit_avx512_core_bf16_1x1_convolution_bwd_data_t<
const size_t str_size = jcp.bcast_dim * max_load_per_thread;
p.store_buffer = store_buffer + ithr * str_size
+ data_blk_off(diff_src_d, 0, 0, id, ih, iw);

p.oc_off = ic_off_idx * (is_dsrc_layout_nxc ? 1 : jcp.ic_block) * sizeof(float);

(*kernel_)(&p);
if (pd()->rtus_.reduce_src_) (*rtus_driver_)(&rp);
};
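Note: the new oc_off kernel argument is the byte offset of the current channel block within the post-op's per-channel weights/bias arrays. For channels-last (nxc) layouts channels are unit-strided, so the offset advances one float per block index; for blocked layouts it advances a full ic_block of floats. A sketch of that arithmetic, mirroring the expression above:

```cpp
#include <cstddef>

// ic_off_idx: current channel-block index; ic_block: blocking factor
// (e.g. 16 for Zmm kernels).
static std::size_t channel_byte_offset(
        int ic_off_idx, bool is_dsrc_layout_nxc, int ic_block) {
    return static_cast<std::size_t>(ic_off_idx)
            * (is_dsrc_layout_nxc ? 1 : ic_block) * sizeof(float);
}
```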
22 changes: 20 additions & 2 deletions src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp
@@ -339,8 +339,9 @@ struct jit_avx512_core_bf16_1x1_convolution_bwd_data_t : public primitive_t {
&& set_default_alg_kind(alg_kind::convolution_direct)
&& expect_data_types(diff_src_type, data_type::bf16,
data_type::undef, data_type::bf16, data_type::undef)
&& attr()->has_default_values() && !has_zero_dim_memory()
&& set_default_formats();
&& !has_zero_dim_memory()
&& set_default_formats()
&& is_supported_post_ops();
if (!ok) return status::unimplemented;

const convolution_desc_t *conv_d = desc();
@@ -376,6 +377,23 @@ struct jit_avx512_core_bf16_1x1_convolution_bwd_data_t : public primitive_t {

return set_default_formats_common(dat_tag, wei_tag, dat_tag);
}

bool is_supported_post_ops() const {
const auto &p = this->attr()->post_ops_;
if (p.len() > 1)
return false;

auto all_post_ops_supported = [&]() {
bool ok = true;

for (int i = 0; i < p.len(); i++) {
ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise);
}
return ok;
};

return all_post_ops_supported();
}
};

template <cpu_isa_t isa, typename conv_t>
55 changes: 54 additions & 1 deletion src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp
@@ -996,6 +996,29 @@ void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::store_output(int ur_w) {
if (!isa_has_bf16(jcp.isa)) bf16_emu_->init_vcvtneps2bf16();
const int ic_tail = jcp.ic_tail;

int depthwise_inj_idx = 0;
const auto& p = attr_.post_ops_;
for (int i = 0; i < p.len(); i++) {
auto& post_op = p.entry_[i];
if (post_op.is_depthwise()) {
mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));

add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]);
add(reg_d_bias, ptr[this->param1 + GET_OFF(oc_off)]);

for (int k = 0; k < jcp.nb_ic_blocking; k++) {
depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
k * jcp.ur_w, k * jcp.ur_w + ur_w, reg_d_weights, reg_d_bias);

add(reg_d_weights, jcp.ic_block * sizeof(float));
add(reg_d_bias, jcp.ic_block * sizeof(float));
}

depthwise_inj_idx++;
}
}

if (jcp.dst_dt == data_type::f32) {
for (int k = 0; k < jcp.nb_ic_blocking; k++)
for (int j = 0; j < ur_w; j++) {
@@ -1238,6 +1261,17 @@ void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::compute_loop(

template <typename Vmm>
void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::generate() {
const auto &p = attr_.post_ops_;
for (int i = 0; i < p.len(); i++) {
auto &post_op = p.entry_[i];
if (post_op.is_depthwise()) {
depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
this,
post_op.depthwise.alg
));
}
}

int iw = jcp.iw;
int kw = jcp.kw;
int ur_w = jcp.ur_w;
@@ -1424,9 +1458,26 @@ void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::generate() {
postamble();
}

bool jit_avx512_core_bf16_bwd_data_kernel::post_ops_ok(
jit_conv_conf_t& jcp, const primitive_attr_t& attr) {
const auto& p = attr.post_ops_;

auto all_post_ops_supported = [&]() {
bool ok = true;

for (int i = 0; i < p.len(); i++) {
ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise);
}
return ok;
};

return all_post_ops_supported();
}

status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(jit_conv_conf_t &jcp,
const convolution_desc_t &cd, memory_desc_t &diff_src_md,
memory_desc_t &weights_md, memory_desc_t &diff_dst_md, int nthreads) {
memory_desc_t &weights_md, memory_desc_t &diff_dst_md,
const primitive_attr_t& attr, int nthreads) {

const memory_desc_wrapper diff_src_d(&diff_src_md);
const memory_desc_wrapper weights_d(&weights_md);
@@ -1582,6 +1633,8 @@ status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(jit_conv_conf_t &jcp,
&& jcp.oc <= weights_d.padded_dims()[with_groups + 0];
if (!args_ok) return status::unimplemented;

if (!post_ops_ok(jcp, attr)) return status::unimplemented;

jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);

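Note: in store_output() the emitted code loads the post-op's weight and bias base pointers, advances both by the runtime oc_off, applies the injector to one ic_block of channels per register block, and bumps the pointers by ic_block * sizeof(float) between blocks. A scalar analogue of that loop structure, assuming a [block][channel][pixel] accumulator layout and a scale_shift post-op (a sketch, not the vectorized JIT code):

```cpp
#include <cstddef>

static void depthwise_over_blocks(float *acc, const float *w_base,
        const float *b_base, std::size_t oc_off_bytes, int nb_ic_blocking,
        int ic_block, int ur_w) {
    // Mirror of: mov(reg_d_weights, ...); add(reg_d_weights, oc_off);
    const float *w = w_base + oc_off_bytes / sizeof(float);
    const float *b = b_base + oc_off_bytes / sizeof(float);
    for (int k = 0; k < nb_ic_blocking; ++k) {
        for (int c = 0; c < ic_block; ++c)
            for (int j = 0; j < ur_w; ++j) {
                float &v = acc[((std::size_t)k * ic_block + c) * ur_w + j];
                v = v * w[c] + b[c];
            }
        // Mirror of: add(reg_d_weights, jcp.ic_block * sizeof(float));
        w += ic_block;
        b += ic_block;
    }
}
```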
35 changes: 27 additions & 8 deletions src/cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp
@@ -275,19 +275,30 @@ struct jit_avx512_core_bf16_fwd_kernel {
template <typename Vmm>
struct _jit_avx512_core_bf16_bwd_data_kernel : public jit_generator {

_jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp)
: jit_generator(nullptr, ker_code_size), jcp(ajcp), bf16_emu_(nullptr) {
_jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp,
const primitive_attr_t& attr)
: jit_generator(nullptr, ker_code_size)
, jcp(ajcp)
, attr_(attr)
, bf16_emu_(nullptr) {
if (!isa_has_bf16(jcp.isa))
bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserv_1,
bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_scratch,
bf16_emu_reserv_4, bf16_emu_reserv_5);
}

~_jit_avx512_core_bf16_bwd_data_kernel() { delete bf16_emu_; }
~_jit_avx512_core_bf16_bwd_data_kernel() {
for (auto inj : depthwise_injectors)
delete inj;
depthwise_injectors.clear();

delete bf16_emu_;
}

DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_bf16_bwd_data_kernel_f32)

const jit_conv_conf_t &jcp;
const primitive_attr_t& attr_;

private:
using Vmm_down_t =
@@ -364,6 +375,12 @@ struct _jit_avx512_core_bf16_bwd_data_kernel : public jit_generator {
Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(30);

Vmm vmm_wei = Vmm(31);

reg64_t reg_d_weights = r15;
reg64_t reg_d_bias = reg_kj;

nstl::vector<jit_uni_depthwise_injector_f32<avx512_common>*> depthwise_injectors;

bf16_emulation_t *bf16_emu_;

inline void prepare_output(int ur_w);
@@ -445,20 +462,20 @@ struct _jit_avx512_core_bf16_bwd_data_kernel : public jit_generator {

struct jit_avx512_core_bf16_bwd_data_kernel {

jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp)
jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp, const primitive_attr_t& attr)
: kernel_(nullptr) {
switch (ajcp.ic_block) {
case 16:
kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Zmm>(
ajcp);
ajcp, attr);
return;
case 8:
kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Ymm>(
ajcp);
ajcp, attr);
return;
case 4:
kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Xmm>(
ajcp);
ajcp, attr);
return;
default: assert(!"invalid channel blocking");
}
@@ -468,10 +485,12 @@

~jit_avx512_core_bf16_bwd_data_kernel() { delete kernel_; }

static bool post_ops_ok(jit_conv_conf_t& jcp, const primitive_attr_t& attr);

static status_t init_conf(jit_conv_conf_t &jcp,
const convolution_desc_t &cd, memory_desc_t &diff_src_md,
memory_desc_t &weights_md, memory_desc_t &diff_dst_md,
int nthreads);
const primitive_attr_t& attr, int nthreads);
void operator()(const jit_conv_call_s *p) const { (*kernel_)(p); }
const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); }

2 changes: 2 additions & 0 deletions src/cpu/x64/jit_avx512_core_bf16_convolution.cpp
@@ -554,6 +554,7 @@ void jit_avx512_core_bf16_convolution_bwd_data_t ::execute_backward_data_3d(
par_conv.filt = wht_w + kh_lo * wht_h_stride;
par_conv.kh_padding = kh_len;
par_conv.kd_padding = kd_len;
par_conv.oc_off = ic_idx * (is_dsrc_layout_nxc ? 1 : jcp.ic_block) * sizeof(float);

(*kernel_)(&par_conv);
}
@@ -693,6 +694,7 @@ void jit_avx512_core_bf16_convolution_bwd_data_t ::execute_backward_data(
par_conv.filt = wht_w + k_lo * wht_h_stride;
par_conv.kh_padding = k_len;
par_conv.iwb = iwb;
par_conv.oc_off = ic_idx * (is_dsrc_layout_nxc ? 1 : jcp.ic_block) * sizeof(float);

(*kernel_)(&par_conv);
}
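Note: the GEMM path earlier in this diff applies the post-op through ref_depthwise_scalar_fwd_t::compute_scalar(). Its actual body lives in cpu/ref_depthwise_injector.hpp (included above); a hedged sketch of the semantics it is expected to implement for the two common depthwise algorithms (an assumption, the real implementation may differ):

```cpp
enum class depthwise_alg { scale_shift, prelu };

// scale_shift: s * w + b; prelu: identity for s >= 0, else s * w
// (bias unused).
static float compute_scalar_sketch(
        depthwise_alg alg, float s, const float *w, const float *b) {
    switch (alg) {
        case depthwise_alg::scale_shift: return s * *w + *b;
        case depthwise_alg::prelu: return s >= 0.f ? s : s * *w;
    }
    return s;
}
```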