Skip to content

Commit

Permalink
Brgemm: new classes
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Feb 20, 2023
1 parent 3e0a7b7 commit e883075
Show file tree
Hide file tree
Showing 7 changed files with 560 additions and 269 deletions.
6 changes: 4 additions & 2 deletions src/plugins/intel_cpu/src/emitters/cpu_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ ov::intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_
jitters[ngraph::snippets::op::LoopBegin::get_type_info_static()] = CREATE_EMITTER(LoopBeginEmitter);
jitters[ngraph::snippets::op::LoopEnd::get_type_info_static()] = CREATE_EMITTER(LoopEndEmitter);
jitters[ov::intel_cpu::BrgemmCPU::get_type_info_static()] = CREATE_EMITTER(BrgemmEmitter);
jitters[ov::intel_cpu::BrgemmIndependentCPU::get_type_info_static()] = CREATE_EMITTER(BrgemmEmitter);
jitters[ov::intel_cpu::BrgemmWithCompensationsCPU::get_type_info_static()] = CREATE_EMITTER(BrgemmWithScratchEmitter);
jitters[ov::intel_cpu::BrgemmAMXCPU::get_type_info_static()] = CREATE_EMITTER(BrgemmWithScratchEmitter);
jitters[ov::intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_EMITTER(BrgemmCopyBEmitter);
}

Expand All @@ -165,8 +168,7 @@ code ov::intel_cpu::CPUTargetMachine::get_snippet() const {
}

ngraph::snippets::TargetMachine::opRegType ov::intel_cpu::CPUTargetMachine::get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const {
if (std::dynamic_pointer_cast<ov::intel_cpu::BrgemmCPU>(op) ||
std::dynamic_pointer_cast<ov::intel_cpu::BrgemmCopyB>(op))
if (std::dynamic_pointer_cast<ov::intel_cpu::BrgemmCopyB>(op))
return gpr2gpr;
else
return vec2vec;
Expand Down
411 changes: 252 additions & 159 deletions src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp

Large diffs are not rendered by default.

63 changes: 47 additions & 16 deletions src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,20 +361,17 @@ class StoreConvertEmitter : public MemoryEmitter {
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
};

class BrgemmEmitter : public jit_emitter {
public:
BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);

size_t get_inputs_num() const override {return 2;}
// Base class for Brgemm emitters with common interface
class BrgemmBaseEmitter : public jit_emitter {
protected:
BrgemmBaseEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);

private:
void emit_impl(const std::vector<size_t>& in,
const std::vector<size_t>& out,
const std::vector<size_t>& pool,
const std::vector<size_t>& gpr,
const ov::intel_cpu::emitter_context *emit_context) const override;

std::vector<size_t> io_data_size {};
struct brgemmCtx {
size_t M, N, K, LDA, LDB, LDC;
dnnl_data_type_t dt_in0, dt_in1;
Expand All @@ -383,14 +380,18 @@ class BrgemmEmitter : public jit_emitter {
bool is_with_comp;
float beta;
};
void initBrgemm(brgemmCtx& ctx, std::unique_ptr<brgemm_kernel_t>& brgKernel, bool use_amx) const;
void callBrgemm(brgemmCtx& ctx, std::unique_ptr<brgemm_kernel_t>& brgKernel, const void* pin0, const void* pin1, void* pout, void* wsp) const;
size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const;

size_t get_brg_idx(size_t mIdx, size_t kIdx, size_t nIdx) const;
OutputVector get_io_values(const std::shared_ptr<ov::Node>& n) const;
void init_brgemm(brgemmCtx& ctx, std::unique_ptr<brgemm_kernel_t>& brgKernel, bool use_amx) const;
virtual std::vector<size_t> init_kernel_offsets(size_t mb, size_t M_blk, size_t LDA, size_t LDC,
size_t k, size_t K0_step0, size_t K0_step1,
size_t n, size_t N0_step0, size_t N0_step1) const = 0;

void emit_brgemm_kernel_call(const brgemm_kernel_t *brg_kernel, const brgemmCtx& ctx,
Reg64 addr_A, Reg64 addr_B, Reg64 scratch, Reg64 addr_C,
const size_t in0_kernel_offset, const size_t in1_kernel_offset,
const size_t in2_kernel_offset, const size_t out0_kernel_offset) const;
const std::vector<Reg64>& regs, const std::vector<size_t>& offsets) const;
virtual void kernel_preparation(const brgemmCtx& ctx) const {}
virtual void kernel_call(const brgemm_kernel_t *brg_kernel, const std::vector<Reg64>& regs, const std::vector<size_t>& offsets) const = 0;
static void kernel_execute(const brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C, void *scratch, int with_comp);

static constexpr size_t BRGEMM_KERNELS_NUM = 8;
Expand All @@ -403,15 +404,45 @@ class BrgemmEmitter : public jit_emitter {
size_t N, N_blk, N_tail;
size_t brg0VnniFactor;

bool with_scratch = false;
bool with_comp = false;
std::vector<size_t> io_data_size {};

size_t load_offset_a = 0lu;
size_t load_offset_b = 0lu;
size_t load_offset_scratch = 0lu;
size_t store_offset_c = 0lu;

bool is_amx = false;
bool with_comp = false;
};

// Brgemm emitter for the two-input case (no scratchpad input).
// Registered in CPUTargetMachine for BrgemmCPU and BrgemmIndependentCPU.
// Supplies the two hooks that BrgemmBaseEmitter declares as pure virtual:
// init_kernel_offsets() and kernel_call().
class BrgemmEmitter : public BrgemmBaseEmitter {
public:
// h   - JIT generator this emitter is constructed with
// isa - target instruction set
// n   - graph node this emitter is created for
BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);

// This emitter consumes exactly two inputs.
size_t get_inputs_num() const override {return 2;}
protected:
// Overrides the pure virtual hook of BrgemmBaseEmitter.
// NOTE(review): the implementation is not visible here; judging by the names, the arguments
// are block indices (mb, k, n), block sizes/steps and leading dimensions, and the returned vector
// presumably holds per-operand byte offsets that are later passed to kernel_call() - confirm in the .cpp.
std::vector<size_t> init_kernel_offsets(size_t mb, size_t M_blk, size_t LDA, size_t LDC,
size_t k, size_t K0_step0, size_t K0_step1,
size_t n, size_t N0_step0, size_t N0_step1) const override;
// Overrides the pure virtual hook of BrgemmBaseEmitter: emits the call of brg_kernel using
// the given registers and offsets.
// NOTE(review): the expected order/size of `regs` and `offsets` is defined by the implementation - confirm in the .cpp.
void kernel_call(const brgemm_kernel_t *brg_kernel, const std::vector<Reg64>& regs, const std::vector<size_t>& offsets) const override;
};

// Brgemm emitter for the three-input case, where the third input is a scratch buffer.
// Registered in CPUTargetMachine for BrgemmWithCompensationsCPU and BrgemmAMXCPU; the
// BrgemmToBrgemmCPU pass creates both of these ops with a `scratch` node as the third input.
// In addition to the two pure virtual hooks of BrgemmBaseEmitter, this class also overrides
// kernel_preparation(), which is an empty no-op in the base class.
class BrgemmWithScratchEmitter : public BrgemmBaseEmitter {
public:
// h   - JIT generator this emitter is constructed with
// isa - target instruction set
// n   - graph node this emitter is created for
BrgemmWithScratchEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);

// This emitter consumes exactly three inputs (the third one is the scratch buffer).
size_t get_inputs_num() const override {return 3;}
protected:
// Overrides the pure virtual hook of BrgemmBaseEmitter.
// NOTE(review): the implementation is not visible here; the returned vector presumably also
// carries an offset for the scratch input - confirm in the .cpp.
std::vector<size_t> init_kernel_offsets(size_t mb, size_t M_blk, size_t LDA, size_t LDC,
size_t k, size_t K0_step0, size_t K0_step1,
size_t n, size_t N0_step0, size_t N0_step1) const override;
// Replaces the empty default from BrgemmBaseEmitter with scratch-specific preparation for the given kernel context.
// NOTE(review): presumably AMX tile configuration and/or compensation setup - confirm in the .cpp.
void kernel_preparation(const brgemmCtx& ctx) const override;
// Overrides the pure virtual hook of BrgemmBaseEmitter: emits the call of brg_kernel using
// the given registers and offsets.
void kernel_call(const brgemm_kernel_t *brg_kernel, const std::vector<Reg64>& regs, const std::vector<size_t>& offsets) const override;

private:
// Offset applied to the scratch input; zero by default.
// NOTE(review): the BrgemmBaseEmitter text in this diff also lists a `load_offset_scratch` member;
// confirm it was removed from the base class so this field does not shadow it.
size_t load_offset_scratch = 0;
};


class BrgemmCopyBEmitter : public jit_emitter {
public:
BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);
Expand Down
3 changes: 3 additions & 0 deletions src/plugins/intel_cpu/src/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ std::map<std::string, ngraph::OpSet> Extension::getOpSets() {
NGRAPH_OP(StoreConvertSaturation, ov::intel_cpu)
NGRAPH_OP(StoreConvertTruncation, ov::intel_cpu)
NGRAPH_OP(BrgemmCPU, ov::intel_cpu)
NGRAPH_OP(BrgemmIndependentCPU, ov::intel_cpu)
NGRAPH_OP(BrgemmWithCompensationsCPU, ov::intel_cpu)
NGRAPH_OP(BrgemmAMXCPU, ov::intel_cpu)
NGRAPH_OP(BrgemmCopyB, ov::intel_cpu)
#undef NGRAPH_OP

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,29 +60,28 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() {
std::shared_ptr<ov::Node> brgemm_cpu = nullptr;
if (one_of(element_type_a, ov::element::f32)) {
brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), brgemm->input_value(1),
brgemm->transposed_a(), brgemm->transposed_b(), with_comp,
brgemm->transposed_a(), brgemm->transposed_b(),
offset_a, offset_b, offset_c);
} else {
const auto layoutIn1 = ngraph::snippets::utils::get_node_output_layout(brgemm->input_value(1).get_node_shared_ptr());
const auto brgemmRepackIn1 = std::make_shared<BrgemmCopyB>(brgemm->input_value(1), element_type_a, with_comp, offset_b);
const auto buffer = std::make_shared<ngraph::snippets::op::IntermediateBuffer>(brgemmRepackIn1->output(0));

if (with_amx || with_comp) {
std::shared_ptr<ngraph::snippets::op::Buffer> scratch = nullptr;
if (with_amx) {
const auto scratch_size = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int32_t>{8 * 1024});
scratch = std::make_shared<ngraph::snippets::op::AllocationBuffer>(scratch_size, ov::element::f32);
} else if (with_comp) {
scratch = std::make_shared<ngraph::snippets::op::IntermediateBuffer>(brgemmRepackIn1->output(1));
}

brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), buffer, scratch,
brgemm->transposed_a(), brgemm->transposed_b(), with_comp,
offset_a, offset_b, offset_c);
if (with_amx) {
const auto scratch_size = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int32_t>{8 * 1024});
const auto scratch = std::make_shared<ngraph::snippets::op::AllocationBuffer>(scratch_size, ov::element::f32);
brgemm_cpu = std::make_shared<BrgemmAMXCPU>(brgemm->input_value(0), buffer, scratch,
brgemm->transposed_a(), brgemm->transposed_b(),
offset_a, offset_b, offset_c);
} else if (with_comp) {
const auto scratch = std::make_shared<ngraph::snippets::op::IntermediateBuffer>(brgemmRepackIn1->output(1));
brgemm_cpu = std::make_shared<BrgemmWithCompensationsCPU>(brgemm->input_value(0), buffer, scratch,
brgemm->transposed_a(), brgemm->transposed_b(),
offset_a, offset_b, offset_c);
} else if (one_of(element_type_a, ov::element::u8, ov::element::bf16)) {
brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), buffer,
brgemm->transposed_a(), brgemm->transposed_b(), with_comp,
offset_a, offset_b, offset_c);
brgemm_cpu = std::make_shared<BrgemmIndependentCPU>(brgemm->input_value(0), buffer,
brgemm->transposed_a(), brgemm->transposed_b(),
offset_a, offset_b, offset_c);
} else {
IE_THROW() << "Invalid configuration for BRGEMM CPU";
}
Expand Down
Loading

0 comments on commit e883075

Please sign in to comment.