Skip to content

Commit

Permalink
BrgemmCopyB: updated classes
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Feb 10, 2023
1 parent 7f6a72b commit fa1bb56
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 153 deletions.
3 changes: 2 additions & 1 deletion src/plugins/intel_cpu/src/emitters/cpu_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ ov::intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_
jitters[ov::intel_cpu::BrgemmWithCompensationsCPU::get_type_info_static()] = CREATE_EMITTER(BrgemmWithScratchEmitter);
jitters[ov::intel_cpu::BrgemmAMXCPU::get_type_info_static()] = CREATE_EMITTER(BrgemmWithScratchEmitter);
jitters[ov::intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_EMITTER(BrgemmCopyBEmitter);
jitters[ov::intel_cpu::BrgemmCopyBWithCompensations::get_type_info_static()] = CREATE_EMITTER(BrgemmCopyBWithCompensationsEmitter);
}

size_t ov::intel_cpu::CPUTargetMachine::get_lanes() const {
Expand All @@ -166,7 +167,7 @@ code ov::intel_cpu::CPUTargetMachine::get_snippet() const {
}

ngraph::snippets::TargetMachine::opRegType ov::intel_cpu::CPUTargetMachine::get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const {
if (std::dynamic_pointer_cast<ov::intel_cpu::BrgemmCopyB>(op))
if (std::dynamic_pointer_cast<ov::intel_cpu::BrgemmCopyBBase>(op))
return gpr2gpr;
else
return vec2vec;
Expand Down
276 changes: 180 additions & 96 deletions src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp

Large diffs are not rendered by default.

41 changes: 34 additions & 7 deletions src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,14 +443,14 @@ class BrgemmWithScratchEmitter : public BrgemmBaseEmitter {
size_t load_offset_scratch = 0;
};


class BrgemmCopyBEmitter : public jit_emitter {
// Base class for BrgemmCopyB emitters with common interface
class BrgemmCopyBBaseEmitter : public jit_emitter {
public:
BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);
BrgemmCopyBBaseEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);

size_t get_inputs_num() const override {return 2;}
size_t get_inputs_num() const override {return 1;}

private:
protected:
void emit_impl(const std::vector<size_t>& in,
const std::vector<size_t>& out,
const std::vector<size_t>& pool,
Expand All @@ -460,11 +460,18 @@ class BrgemmCopyBEmitter : public jit_emitter {
void init_brgemm_copy(std::unique_ptr<matmul::jit_brgemm_matmul_copy_b_t>& kernel,
size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K,
bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const;
void emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, Reg64 src, Reg64 dst, Reg64 comp,
size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) const;
void emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel,
const std::vector<Reg64>& regs, const std::vector<size_t>& offsets,
size_t N, size_t K) const;

virtual void kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, const std::vector<Reg64>& regs, const std::vector<size_t>& offsets) const = 0;
virtual std::vector<size_t> init_kernel_offsets(size_t nb, size_t N_blk, size_t brgemmVNNIFactor, size_t data_size) const = 0;

static void execute(matmul::jit_brgemm_matmul_copy_b_t* kernel, const void* src, const void* dst, const void* comp, size_t N, size_t K);

inline void data_ptr(Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) const;
inline void push_value(size_t value, size_t index, size_t gpr_size) const;

std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t> kernel;

ov::element::Type brgemm_prc_in0, brgemm_prc_in1;
Expand All @@ -476,6 +483,26 @@ class BrgemmCopyBEmitter : public jit_emitter {

size_t in_offset = 0lu;
size_t out_offset = 0lu;
};

class BrgemmCopyBEmitter : public BrgemmCopyBBaseEmitter {
public:
BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);

protected:
void kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, const std::vector<Reg64>& regs, const std::vector<size_t>& offsets) const override;
std::vector<size_t> init_kernel_offsets(size_t nb, size_t N_blk, size_t brgemmVNNIFactor, size_t data_size) const override;
};

class BrgemmCopyBWithCompensationsEmitter : public BrgemmCopyBBaseEmitter {
public:
BrgemmCopyBWithCompensationsEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);

protected:
void kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, const std::vector<Reg64>& regs, const std::vector<size_t>& offsets) const override;
std::vector<size_t> init_kernel_offsets(size_t nb, size_t N_blk, size_t brgemmVNNIFactor, size_t data_size) const override;

private:
size_t comp_offset = 0lu;
};

Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ std::map<std::string, ngraph::OpSet> Extension::getOpSets() {
NGRAPH_OP(BrgemmWithCompensationsCPU, ov::intel_cpu)
NGRAPH_OP(BrgemmAMXCPU, ov::intel_cpu)
NGRAPH_OP(BrgemmCopyB, ov::intel_cpu)
NGRAPH_OP(BrgemmCopyBWithCompensations, ov::intel_cpu)
#undef NGRAPH_OP

return opset;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() {
offset_a, offset_b, offset_c);
} else {
const auto layoutIn1 = ngraph::snippets::utils::get_node_output_layout(brgemm->input_value(1).get_node_shared_ptr());
const auto brgemmRepackIn1 = std::make_shared<BrgemmCopyB>(brgemm->input_value(1), element_type_a, with_comp, offset_b);
std::shared_ptr<ov::Node> brgemmRepackIn1 = nullptr;
if (with_comp) {
brgemmRepackIn1 = std::make_shared<BrgemmCopyBWithCompensations>(brgemm->input_value(1), element_type_a, offset_b);
} else {
brgemmRepackIn1 = std::make_shared<BrgemmCopyB>(brgemm->input_value(1), element_type_a, offset_b);
}
const auto buffer = std::make_shared<ngraph::snippets::op::IntermediateBuffer>(brgemmRepackIn1->output(0));

if (with_amx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,41 +12,23 @@
using namespace std;
using namespace ov;

intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output<Node>& x, const element::Type src_type, const bool with_comp,
const size_t offset_in, const size_t offset_out0, const size_t offset_out1)
: ngraph::snippets::op::MemoryAccess({x}), m_with_comp(with_comp), m_src_type(src_type) {
intel_cpu::BrgemmCopyBBase::BrgemmCopyBBase(const Output<Node>& x, const element::Type src_type,
const size_t offset_in, const size_t offset_out)
: ngraph::snippets::op::MemoryAccess({x}), m_src_type(src_type) {
set_input_port_descriptor({0, offset_in}, 0);
set_output_port_descriptor({0, offset_out0}, 0);
if (with_comp) {
set_output_port_descriptor({0, offset_out1}, 1);
set_output_size(2);
} else {
set_output_size(1);
}
constructor_validate_and_infer_types();
set_output_port_descriptor({0, offset_out}, 0);
}

bool intel_cpu::BrgemmCopyB::visit_attributes(AttributeVisitor& visitor) {
INTERNAL_OP_SCOPE(BrgemmRepack_visit_attributes);
MemoryAccess::visit_attributes(visitor);
visitor.on_attribute("with_comp", m_with_comp);
visitor.on_attribute("src_type", m_src_type);
return true;
}

void intel_cpu::BrgemmCopyB::validate_and_infer_types() {
INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types);
void intel_cpu::BrgemmCopyBBase::validate_and_infer_types() {
INTERNAL_OP_SCOPE(BrgemmCopyBBase_validate_and_infer_types);

const auto element_type = get_input_element_type(0);
NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8),
"BrgemmCopyB doesn't support element type" + element_type.get_type_name());
"BrgemmCopyBBase doesn't support element type" + element_type.get_type_name());

const auto pshape = ngraph::snippets::utils::get_port_planar_shape(input_value(0));
if (pshape.is_dynamic()) {
set_output_type(0, element_type, ov::PartialShape{ov::Dimension::dynamic()});
if (m_with_comp) {
set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension::dynamic()});
}
return;
}

Expand All @@ -58,16 +40,60 @@ void intel_cpu::BrgemmCopyB::validate_and_infer_types() {

set_output_type(0, element_type, ov::PartialShape{ov::Dimension(rnd_up(K, brgemmVNNIFactor)),
ov::Dimension(rnd_up(N, N_blk))});
if (m_with_comp) {
set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, N_blk))});
}
}

bool intel_cpu::BrgemmCopyBBase::visit_attributes(AttributeVisitor& visitor) {
INTERNAL_OP_SCOPE(BrgemmCopyBBase_visit_attributes);
MemoryAccess::visit_attributes(visitor);
visitor.on_attribute("src_type", m_src_type);
return true;
}

intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output<Node>& x, const element::Type src_type,
const size_t offset_in, const size_t offset_out)
: BrgemmCopyBBase(x, src_type, offset_in, offset_out) {
set_output_size(1);
constructor_validate_and_infer_types();
}

std::shared_ptr<Node> intel_cpu::BrgemmCopyB::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(BrgemmRepack_clone_with_new_inputs);
INTERNAL_OP_SCOPE(BrgemmCopyB_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<BrgemmCopyB>(new_args.at(0), m_src_type, m_with_comp,
return std::make_shared<BrgemmCopyB>(new_args.at(0), m_src_type,
get_offset_in(),
get_offset_out(),
m_with_comp ? get_offset_comp() : 0);
get_offset_out());
}

intel_cpu::BrgemmCopyBWithCompensations::BrgemmCopyBWithCompensations(const Output<Node>& x, const element::Type src_type,
const size_t offset_in, const size_t offset_out0, const size_t offset_out1)
: BrgemmCopyBBase(x, src_type, offset_in, offset_out0) {
set_output_port_descriptor({0, offset_out1}, 1);
set_output_size(2);
constructor_validate_and_infer_types();
}

void intel_cpu::BrgemmCopyBWithCompensations::validate_and_infer_types() {
INTERNAL_OP_SCOPE(BrgemmCopyBWithCompensations_validate_and_infer_types);
BrgemmCopyBBase::validate_and_infer_types();

const auto pshape = ngraph::snippets::utils::get_port_planar_shape(input_value(0));
if (pshape.is_dynamic()) {
set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension::dynamic()});
return;
}

const auto shape = pshape.get_shape();
const auto N = *shape.rbegin();
const auto N_blk = get_input_element_type(0) == element::bf16 ? 32 : 64;

set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, N_blk))});
}

std::shared_ptr<Node> intel_cpu::BrgemmCopyBWithCompensations::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(BrgemmCopyBWithCompensations_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<BrgemmCopyBWithCompensations>(new_args.at(0), m_src_type,
get_offset_in(),
get_offset_out(),
get_offset_comp());
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,64 @@ namespace ov {
namespace intel_cpu {

/**
* @interface BrgemmCopyB
* @brief The operation for data repacking of Brgemm with input non-fp32 precisions
* @interface BrgemmCopyBBase
* @brief The base class with the common interface for data repacking of Brgemm with input non-fp32 precisions
* @ingroup snippets
*/
class BrgemmCopyB : public ngraph::snippets::op::MemoryAccess {
class BrgemmCopyBBase : public ngraph::snippets::op::MemoryAccess {
public:
OPENVINO_OP("BrgemmCopyB", "SnippetsOpset", MemoryAccess);
BrgemmCopyB(const Output<Node>& x, const element::Type src_type, const bool with_comp = false,
const size_t offset_in = 0lu, const size_t offset_out0 = 0lu, const size_t offset_out1 = 0lu);
BrgemmCopyB() = default;
OPENVINO_OP("BrgemmCopyBBase", "SnippetsOpset", MemoryAccess);
BrgemmCopyBBase() = default;

size_t get_offset_in() const { return get_input_port_descriptor(0).m_offset; }
size_t get_offset_out() const { return get_output_port_descriptor(0).m_offset; }
size_t get_offset_comp() const { return get_output_port_descriptor(1).m_offset; }

element::Type get_src_element_type() const { return m_src_type; }
bool is_with_comp() const { return m_with_comp; }

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool has_evaluate() const override { return false; }

protected:
BrgemmCopyBBase(const Output<Node>& x, const element::Type src_type,
const size_t offset_in = 0lu, const size_t offset_out = 0lu);

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override { return nullptr; };

element::Type m_src_type; // src element type of the corresponding BRGEMM (first input)
};

/**
* @interface BrgemmCopyB
* @brief The operation for data repacking of Brgemm with input non-fp32 precisions without compensations (doesn't have 2nd output)
* @ingroup snippets
*/
class BrgemmCopyB : public BrgemmCopyBBase {
public:
OPENVINO_OP("BrgemmCopyB", "SnippetsOpset", BrgemmCopyBBase);
BrgemmCopyB(const Output<Node>& x, const element::Type src_type,
const size_t offset_in = 0lu, const size_t offset_out = 0lu);
BrgemmCopyB() = default;

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};

private:
bool m_with_comp = false;
element::Type m_src_type; // src element type of the corresponding BRGEMM
/**
* @interface BrgemmCopyBWithCompensations
* @brief The operation for data repacking of Brgemm with input non-fp32 precisions with compensations (has 2nd output)
* @ingroup snippets
*/
class BrgemmCopyBWithCompensations : public BrgemmCopyBBase {
public:
OPENVINO_OP(" BrgemmCopyBWithCompensations", "SnippetsOpset", BrgemmCopyBBase);
BrgemmCopyBWithCompensations(const Output<Node>& x, const element::Type src_type,
const size_t offset_in = 0lu, const size_t offset_out0 = 0lu, const size_t offset_out1 = 0lu);
BrgemmCopyBWithCompensations() = default;

size_t get_offset_comp() const { return get_output_port_descriptor(1).m_offset; }

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};

} // namespace intel_cpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ std::shared_ptr<Node> BrgemmCPU::clone_with_new_inputs(const OutputVector& new_a

// ============================= BrgemmWithRepackingCPU ==============================

std::shared_ptr<BrgemmCopyB> BrgemmWithRepackingCPU::get_brgemm_copy() const {
std::shared_ptr<BrgemmCopyBBase> BrgemmWithRepackingCPU::get_brgemm_copy() const {
if (const auto buffer = ov::as_type_ptr<ngraph::snippets::op::IntermediateBuffer>(get_input_node_shared_ptr(1))) {
return ov::as_type_ptr<BrgemmCopyB>(buffer->get_input_node_shared_ptr(0));
return ov::as_type_ptr<BrgemmCopyBBase>(buffer->get_input_node_shared_ptr(0));
}
return nullptr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class BrgemmWithRepackingCPU : public ngraph::snippets::op::Brgemm {
OPENVINO_OP("BrgemmWithRepackingCPU", "SnippetsOpset", ngraph::snippets::op::Brgemm);
BrgemmWithRepackingCPU() = default;

std::shared_ptr<BrgemmCopyB> get_brgemm_copy() const;
std::shared_ptr<BrgemmCopyBBase> get_brgemm_copy() const;

protected:
void validate_output();
Expand Down

0 comments on commit fa1bb56

Please sign in to comment.