Skip to content

Commit

Permalink
cpu: x64: binary injector: add partial broadcast offset caching
Browse files Browse the repository at this point in the history
  • Loading branch information
tczeszun authored and usstq committed Nov 7, 2022
1 parent 6df930d commit 0d12402
Show file tree
Hide file tree
Showing 35 changed files with 774 additions and 304 deletions.
8 changes: 4 additions & 4 deletions src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ jit_brdgmm_kernel_base_t::jit_brdgmm_kernel_base_t(const brgemm_t &abrd)
= {broadcasting_strategy_t::scalar,
broadcasting_strategy_t::per_oc};
const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<size_t>(vmm_b().getIdx()), r14, r15, preserve_gpr,
preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec),
GET_OFF(data_C_ptr_), dst_md_wrapper,
static_cast<size_t>(n_vlen_tail()), k_mask,
static_cast<size_t>(vmm_b().getIdx()), r14, r15, r13,
preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_),
dst_md_wrapper, static_cast<size_t>(n_vlen_tail()), k_mask,
use_exact_tail_scalar_bcast};
const binary_injector::static_params_t bsp {
this->param1, enabled_bcast_strategy, rhs_sp};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator {
broadcasting_strategy_t::no_broadcast};
const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<size_t>(Xbyak::Zmm(1).getIdx()), this->r14,
this->r15, preserve_gpr, preserve_vmm,
this->r15, this->r13, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_),
dst_md_wrapper, static_cast<size_t>(brg.ldb_tail),
ld_tail_mask, use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
broadcasting_strategy_t::no_broadcast};
const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<size_t>(Vmm(1).getIdx()), this->r14, this->r15,
preserve_gpr, preserve_vmm,
this->r13, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_),
dst_md_wrapper, static_cast<size_t>(brg.ldb_tail),
ld_tail_mask, use_exact_tail_scalar_bcast};
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/x64/gemm_bf16_convolution.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -111,7 +111,7 @@ gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::pp_ker_t(const pd_t *pd)
static constexpr size_t tail_size = 0;
static constexpr bool use_exact_tail_scalar_bcast = false;
const binary_injector::rhs_arg_static_params_t rhs_sp {
helper_vmm_idx, r13, r14, preserve_gpr,
helper_vmm_idx, r13, r14, r15, preserve_gpr,
preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()),
tail_size, kreg_rem_mask, use_exact_tail_scalar_bcast};
Expand Down
855 changes: 625 additions & 230 deletions src/cpu/x64/injectors/jit_uni_binary_injector.cpp

Large diffs are not rendered by default.

128 changes: 97 additions & 31 deletions src/cpu/x64/injectors/jit_uni_binary_injector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops,
* stored inside rhs_addr_reg.
* @param rhs_helper_reg - gpr register used as helper for calculations during data
* loading phase.
* @param rhs_addr_cache_reg - gpr register used for caching part of calculated
* offset.
* @param preserve_gpr_helpers - determines whether gpr registers specified above
* should be preserved (pushed to stack and poped back afterwords) between
* compute_vector_range calls.
Expand All @@ -105,40 +107,46 @@ bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops,
struct rhs_arg_static_params_t {
rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
const Xbyak::Reg64 &rhs_addr_reg,
const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers,
const Xbyak::Reg64 &rhs_helper_reg,
const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
bool preserve_vmm_helper, std::size_t abi_param_offset,
const memory_desc_wrapper &dst_d, std::size_t tail_size = 0u,
bool use_exact_tail_scalar_bcast = false);
rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
const Xbyak::Reg64 &rhs_addr_reg,
const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers,
const Xbyak::Reg64 &rhs_helper_reg,
const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
bool preserve_vmm_helper, std::size_t abi_param_offset,
std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
std::size_t tail_size = 0u,
bool use_exact_tail_scalar_bcast = false);
rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
const Xbyak::Reg64 &rhs_addr_reg,
const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers,
const Xbyak::Reg64 &rhs_helper_reg,
const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
bool preserve_vmm_helper, std::size_t abi_param_offset,
const memory_desc_wrapper &dst_d, std::size_t tail_size,
const Xbyak::Opmask &tail_opmask, bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx = 0);
rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
const Xbyak::Reg64 &rhs_addr_reg,
const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers,
const Xbyak::Reg64 &rhs_helper_reg,
const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
bool preserve_vmm_helper, std::size_t abi_param_offset,
std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx = 0);
rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
const Xbyak::Reg64 &rhs_addr_reg,
const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers,
const Xbyak::Reg64 &rhs_helper_reg,
const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
bool preserve_vmm_helper, std::size_t abi_param_offset,
const memory_desc_wrapper &dst_d, std::size_t tail_size,
const Xbyak::Opmask &tail_opmask, const Xbyak::Reg64 &reg_tail_size,
bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx = 0);
rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
const Xbyak::Reg64 &rhs_addr_reg,
const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers,
const Xbyak::Reg64 &rhs_helper_reg,
const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
bool preserve_vmm_helper, std::size_t abi_param_offset,
std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
Expand All @@ -151,6 +159,7 @@ struct rhs_arg_static_params_t {
mutable std::size_t rhs_dt_helper_vmm_idx = 0;
Xbyak::Reg64 rhs_addr_reg;
Xbyak::Reg64 rhs_helper_reg;
Xbyak::Reg64 rhs_addr_cache_reg;
bool preserve_gpr_helpers;
bool preserve_vmm_helper;
std::size_t abi_param_offset;
Expand All @@ -167,7 +176,8 @@ struct rhs_arg_static_params_t {
private:
rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
const Xbyak::Reg64 &rhs_addr_reg,
const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers,
const Xbyak::Reg64 &rhs_helper_reg,
const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
bool preserve_vmm_helper, std::size_t abi_param_offset,
std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
Expand Down Expand Up @@ -375,7 +385,8 @@ class jit_uni_binary_injector_t {
Xbyak::Address prepare_rhs_arg_addr(std::size_t vmm_idx,
std::size_t rhs_arg_idx, const dnnl_post_ops::entry_t &post_op,
const rhs_arg_dynamic_params_t &rhs_arg_params,
const broadcasting_strategy_t rhs_broadcasting_strategy) const;
const broadcasting_strategy_t rhs_broadcasting_strategy,
bool is_first) const;
/*
* Loads data and applies particular binary operation.
*/
Expand Down Expand Up @@ -403,69 +414,124 @@ class jit_uni_binary_injector_t {
const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
int vmm_idx, const Xbyak::Reg64 &addr_reg,
const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const;
void calculate_no_broadcast(Xbyak::Address addr, std::size_t offset,
const Xbyak::Reg64 &out_reg) const;
const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
bool is_first) const;
void calculate_no_broadcast_base(
Xbyak::Address addr, const Xbyak::Reg64 &out_reg) const;
void calculate_no_broadcast_partial(const std::size_t offset,
const Xbyak::Reg64 &out_reg, std::size_t elem_size_bytes) const;

void append_oc_offset(
const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
int vmm_idx, const Xbyak::Reg64 &addr_reg,
const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const;
void calculate_oc_ncsp(
const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
bool is_first) const;
void calculate_oc_ncsp_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_oc_blocked(
void calculate_oc_ncsp_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;
void calculate_oc_blocked_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_oc_nspc(
void calculate_oc_blocked_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;
void calculate_oc_nspc_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_oc_cspn(
void calculate_oc_nspc_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;
void calculate_oc_cspn_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_oc_cspn_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;

void append_mb_sp_offset(
const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
int vmm_idx, const Xbyak::Reg64 &addr_reg,
const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const;
void calculate_mb_sp_ncsp(
const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
bool is_first) const;
void calculate_mb_sp_ncsp_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_mb_sp_blocked(
void calculate_mb_sp_ncsp_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;
void calculate_mb_sp_blocked_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_mb_sp_nspc(
void calculate_mb_sp_blocked_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;
void calculate_mb_sp_nspc_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_mb_sp_cspn(
void calculate_mb_sp_nspc_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;
void calculate_mb_sp_cspn_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_mb_sp_cspn_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;

void append_mb_w_offset(
const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
int vmm_idx, const Xbyak::Reg64 &addr_reg,
const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const;
void calculate_mb_w_ncsp(
const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
bool is_first) const;
void calculate_mb_w_ncsp_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_mb_w_blocked(
void calculate_mb_w_ncsp_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;
void calculate_mb_w_blocked_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_mb_w_nspc(
void calculate_mb_w_blocked_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;
void calculate_mb_w_nspc_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_mb_w_cspn(
void calculate_mb_w_nspc_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;
void calculate_mb_w_cspn_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_mb_w_cspn_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;

void append_w_offset(
const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
int vmm_idx, const Xbyak::Reg64 &addr_reg,
const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const;
void calculate_w_ncsp(
const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
bool is_first) const;
void calculate_w_ncsp_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_w_blocked(
void calculate_w_ncsp_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;
void calculate_w_blocked_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_w_nspc(
void calculate_w_blocked_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;
void calculate_w_nspc_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_w_cspn(
void calculate_w_nspc_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;
void calculate_w_cspn_base(
const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
void calculate_w_cspn_partial(const dim_t *strides,
const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
std::size_t elem_size_bytes) const;

template <typename T>
typename std::enable_if<std::is_same<T, Xbyak::Zmm>::value
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jit_avx2_1x1_conv_kernel_f32::jit_avx2_1x1_conv_kernel_f32(
const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;

rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r13, r14,
preserve_gpr, preserve_vmm,
r15, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size,
use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx2_conv_kernel_f32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jit_avx2_conv_fwd_kernel_f32::jit_avx2_conv_fwd_kernel_f32(
const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;

rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r13, r14,
preserve_gpr, preserve_vmm,
r15, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size,
use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jit_avx512_common_1x1_conv_kernel::jit_avx512_common_1x1_conv_kernel(
static constexpr bool use_exact_tail_scalar_bcast = true;

const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
r14, r15, preserve_gpr, preserve_vmm,
r14, r15, r12, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, k_load_dim_mask,
use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx512_common_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ _jit_avx512_common_conv_fwd_kernel<Vmm>::_jit_avx512_common_conv_fwd_kernel(
static constexpr bool use_exact_tail_scalar_bcast = false;

const binary_injector::rhs_arg_static_params_t rhs_args_static_params {
helper_vmm_idx, reg_tmp, r15, preserve_gpr, preserve_vmm,
helper_vmm_idx, reg_tmp, r15, r14, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, postops_mask,
use_exact_tail_scalar_bcast};
Expand Down
6 changes: 4 additions & 2 deletions src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@ jit_avx512_core_amx_1x1_fwd_kernel_t::jit_avx512_core_amx_1x1_fwd_kernel_t(
using namespace binary_injector;
const auto &rhs_addr_reg = bin_injector_helper_reg_1;
const auto &rhs_helper_reg = bin_injector_helper_reg_2;
const auto &rhs_addr_cache_reg = bin_injector_helper_reg_3;
static constexpr bool preserve_gpr = false;
static constexpr bool preserve_vmm = false;
const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;
static constexpr bool use_exact_tail_scalar_bcast = true;

const rhs_arg_static_params_t rhs_arg_static_params {31, rhs_addr_reg,
rhs_helper_reg, preserve_gpr, preserve_vmm,
rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, ktail_mask,
use_exact_tail_scalar_bcast};
Expand Down Expand Up @@ -146,7 +147,8 @@ void jit_avx512_core_amx_1x1_fwd_kernel_t::interleave_store() {
const injector_utils::conditional_register_preserve_guard_t
cond_register_guard(jcp.with_binary, this,
{bin_injector_helper_reg_1,
bin_injector_helper_reg_2});
bin_injector_helper_reg_2,
bin_injector_helper_reg_3});
const int wsp_row_offset = jcp.typesize_acc
* (osb * jcp.nb_oc_blocking * jcp.max_width * jcp.oc_block
+ ocb * jcp.max_width * jcp.oc_block
Expand Down
Loading

0 comments on commit 0d12402

Please sign in to comment.