Skip to content

Commit

Permalink
Alexandra's comments applied
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Oct 12, 2023
1 parent 6b0346a commit 3d5690d
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 87 deletions.
23 changes: 16 additions & 7 deletions src/common/snippets/include/snippets/lowered/loop_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,16 @@ class LinearIR::LoopManager {
struct LoopPort {
LoopPort() = default;
LoopPort(const ExpressionPort& port, bool is_incremented = true, size_t dim_idx = 0)
: expr_port(std::make_shared<ExpressionPort>(port)), is_incremented(is_incremented), dim_idx(dim_idx) {}
: expr_port(std::make_shared<ExpressionPort>(port)),
is_incremented(is_incremented),
dim_idx(dim_idx) {
OPENVINO_ASSERT(dim_idx < port.get_descriptor_ptr()->get_shape().size(),
"LoopPort dim_idx (",
dim_idx,
") must be less than the corresponding expression port shape rank (",
port.get_descriptor_ptr()->get_shape().size(),
")");
}

friend bool operator==(const LoopPort& lhs, const LoopPort& rhs);
friend bool operator!=(const LoopPort& lhs, const LoopPort& rhs);
Expand Down Expand Up @@ -93,12 +102,12 @@ class LinearIR::LoopManager {
const std::vector<T>& entries,
const std::vector<T>& exits) {
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, work_amount_increment, entries, exits);
for (auto& entry : loop_info->entry_points) {
entry.dim_idx = dim_idx;
}
for (auto& exit : loop_info->exit_points) {
exit.dim_idx = dim_idx;
}
auto set_common_dim_idx = [dim_idx](std::vector<LoopPort>& ports) {
for (auto& port : ports)
port.dim_idx = dim_idx;
};
set_common_dim_idx(loop_info->entry_points);
set_common_dim_idx(loop_info->exit_points);
const auto loop_id = this->add_loop_info(loop_info);
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
insert_loop_id(*expr_it, loop_id);
Expand Down
3 changes: 1 addition & 2 deletions src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ LinearIR::LoopManager::LoopInfo::LoopInfo(size_t work_amount, size_t increment,
}

size_t LinearIR::LoopManager::LoopInfo::get_dim_idx() const {
if (entry_points.empty())
return SIZE_MAX;
OPENVINO_ASSERT(!entry_points.empty(), "Loop info must have at least one entry point");
auto equal_dim_idxes = [&](const LinearIR::LoopManager::LoopPort& p) {
return p.dim_idx == entry_points[0].dim_idx;
};
Expand Down
31 changes: 0 additions & 31 deletions src/common/snippets/src/lowered/pass/identify_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,39 +51,8 @@ std::vector<bool> IdentifyBuffers::create_adjacency_matrix(const LinearIR& linea
}
};

auto is_buffer = [](const ExpressionPort& port) {
return ov::is_type<op::Buffer>(port.get_expr()->get_node());
};

for (auto expr_it = linear_ir.cbegin(); expr_it != linear_ir.cend(); expr_it++) {
const auto &expr = *expr_it;
if (const auto brgemm = ov::as_type_ptr<op::Brgemm>(expr->get_node())) {
const auto consumers = expr->get_output_port_connector(0)->get_consumers();

auto buffer_it = std::find_if(consumers.begin(), consumers.end(), is_buffer);
if (buffer_it == consumers.end())
continue;
OPENVINO_ASSERT(std::count_if(consumers.begin(), consumers.end(), is_buffer) == 1, "Brgemm mustn't have more than 1 consumer buffer");

std::vector<std::shared_ptr<op::Buffer>> adjacency_buffers;
adjacency_buffers.push_back(ov::as_type_ptr<op::Buffer>(buffer_it->get_expr()->get_node()));

for (const auto& input_connector : expr->get_input_port_connectors()) {
const auto parent_node = input_connector->get_source().get_expr()->get_node();
if (const auto neighbour_buffer = ov::as_type_ptr<op::Buffer>(parent_node)) {
adjacency_buffers.push_back(neighbour_buffer);
}
}
for (auto buffer_it = adjacency_buffers.begin(); buffer_it != adjacency_buffers.end(); ++buffer_it) {
for (auto neighbour_it = std::next(buffer_it); neighbour_it != adjacency_buffers.end(); ++neighbour_it) {
const auto buffer_idx = get_buffer_idx(*buffer_it);
const auto neighbour_idx = get_buffer_idx(*neighbour_it);
adj[index(size, neighbour_idx, buffer_idx)] = adj[index(size, buffer_idx, neighbour_idx)] = true;
}
}
continue;
}

const auto& loop_end = ov::as_type_ptr<op::LoopEnd>(expr->get_node());
if (!loop_end)
continue;
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/pass/insert_tail_loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ bool InsertTailLoop::run(LinearIR& linear_ir) {
continue;

const auto loop_info = loop_manager->get_loop_info(loop_end->get_id());
if (loop_info->fst_iter_handler != nullptr) {
if (loop_info->fst_iter_handler) {
modified |= loop_info->fst_iter_handler(linear_ir, expr_it);
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ bool OptimizeLoopSingleEvaluation::run(LinearIR& linear_ir) {
return false;

bool is_modified = false;
for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) {
if (auto loop_end = ov::as_type_ptr<op::LoopEnd>(expr_it->get()->get_node())) {
for (const auto& expr : linear_ir) {
if (auto loop_end = ov::as_type_ptr<op::LoopEnd>(expr->get_node())) {
// *1* solo vector/tail loop + empty outer loop
// => skip increments (both counter & ptr) : set evaluate_once flag
// *2* solo vector/tail loop + non-empty outer loop
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/tests/src/lowering_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@ std::shared_ptr<ov::snippets::op::Subgraph>
const ov::snippets::lowered::pass::PassPipeline& lowered_pre_common,
const ov::snippets::lowered::pass::PassPipeline& lowered_post_common,
const std::shared_ptr<ov::snippets::Generator>& generator,
const std::shared_ptr<IShapeInferSnippetsFactory>& factory) {
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory) {
auto subgraph = getTokenizedSubgraph(f);
subgraph->set_generator(generator == nullptr ? std::make_shared<DummyGenerator>() : generator);
subgraph->set_master_shape(master_shape);
subgraph->set_tile_rank(2);
// Note: lowered_pipeline would have no effect on subgraph body, since it's applied on linear IR
subgraph->generate(backend_passes, lowered_pre_common, lowered_post_common, factory);
subgraph->generate(backend_passes, lowered_pre_common, lowered_post_common, shape_infer_factory);
return subgraph;
}

Expand Down
16 changes: 2 additions & 14 deletions src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,20 +803,8 @@ std::set<std::vector<element::Type>> BrgemmEmitter::get_supported_precisions(con
}

void BrgemmEmitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
std::set<size_t> unique_ids{in[0], in[1], out[0]};
size_t unique_ids_count = 3;
auto add_reg_to_unique_ids = [&](const size_t reg_number) {
unique_ids.insert(reg_number);
unique_ids_count++;
};

if (m_with_scratch) {
if (in.size() != 3)
IE_THROW() << "BRGEMM Emitter expects 3 inputs if there are compensations/wsp";
add_reg_to_unique_ids(in[2]);
}
if (unique_ids.size() != unique_ids_count) {
IE_THROW() << "BRGEMM Emitter expects that all input/output registers are unique";
if (m_with_scratch && in.size() != 3) {
IE_THROW() << "BRGEMM Emitter expects 3 inputs if there are compensations/wsp";
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,32 +57,31 @@ void BrgemmCopyB::custom_constructor_validate_and_infer_types(std::vector<size_t
// During ctor call, BrgemmCopyB doesn't know his port descriptors.
// So we use port descs from source inputs
const auto element_type = get_input_element_type(0);
const auto& pshape = get_input_partial_shape(0);
validate_element_type(element_type);
// The data always store in planar shape after repacking
const auto planar_pshape = snippets::utils::get_planar_pshape(pshape, layout_input);
const auto planar_pshape = snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_input);
// data repacking output
set_output_type(0, element_type, planar_pshape);
// If compensations are needed, they are provided in 2nd output (which is used in BrgemmCPU)
if (is_with_compensations()) {
set_output_type(1, ov::element::f32, planar_pshape);
}
validate(planar_pshape, element_type);
}

void BrgemmCopyB::validate_and_infer_types() {
INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types);
const auto& element_type = get_input_element_type(0);
validate_element_type(element_type);
const auto port = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0));
const auto shape = ov::Shape(port->get_shape());
const auto& element_type = get_input_element_type(0);
const auto& planar_pshape = snippets::utils::get_planar_pshape(shape, port->get_layout());
set_output_type(0, element_type, planar_pshape);
if (is_with_compensations()) {
set_output_type(1, ov::element::f32, planar_pshape);
}
validate(planar_pshape, element_type);
}

void BrgemmCopyB::validate(const ov::PartialShape& planar_pshape, const ov::element::Type& element_type) {
void BrgemmCopyB::validate_element_type(const ov::element::Type& element_type) {
OPENVINO_ASSERT(one_of(element_type, element::bf16, element::i8),
"BrgemmCopyB doesn't support element type" + element_type.get_type_name());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class BrgemmCopyB : public snippets::op::MemoryAccess {

private:
void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_input = {});
void validate(const ov::PartialShape& planar_pshape, const ov::element::Type& element_type);
void validate_element_type(const ov::element::Type& element_type);
void compute_block_size_values(const size_t blk_size_k, const size_t blk_size_n);

Type m_type = Type::OnlyRepacking;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace intel_cpu {
BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type type,
const size_t offset_a, const size_t offset_b, const size_t offset_c,
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n)
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n, const float beta)
: Brgemm(), m_type(type) {
// We call default ctor of Brgemm class to avoid incorrect shape infer in constructor_validate_and_type_infer() call
set_arguments({A, B});
Expand All @@ -32,8 +32,8 @@ BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type ty
BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, const Type type,
const size_t offset_a, const size_t offset_b, const size_t offset_scratch, const size_t offset_c,
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n)
: Brgemm(), m_type(type) {
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n, const float beta)
: Brgemm(), m_type(type), m_beta(beta) {
set_arguments({A, B, scratch});
set_output_size(1);
ctor_initialize(std::set<size_t>{0, 1, 2}, std::set<size_t>{0});
Expand All @@ -48,8 +48,8 @@ BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<
BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type type,
const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_c,
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n)
: Brgemm(), m_type(type) {
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n, const float beta)
: Brgemm(), m_type(type), m_beta(beta) {
set_arguments({A, B});
set_output_size(1);
m_input_ports = {{0, desc_a}, {1, desc_b}};
Expand All @@ -61,8 +61,8 @@ BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type ty
BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, const Type type,
const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_scratch, const PortDescriptor& desc_c,
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n)
: Brgemm(), m_type(type) {
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n, const float beta)
: Brgemm(), m_type(type), m_beta(beta) {
set_arguments({A, B, scratch});
set_output_size(1);
m_input_ports = {{0, desc_a}, {1, desc_b}, {2, desc_scratch}};
Expand Down Expand Up @@ -136,22 +136,20 @@ std::shared_ptr<Node> BrgemmCPU::clone_with_new_inputs(const OutputVector& new_a
check_new_args_count(this, new_args);
std::shared_ptr<BrgemmCPU> brgemm;
if (!is_with_scratchpad()) {
brgemm = std::make_shared<BrgemmCPU>(new_args.at(0), new_args.at(1), m_type,
return std::make_shared<BrgemmCPU>(new_args.at(0), new_args.at(1), m_type,
get_input_port_descriptor(0), get_input_port_descriptor(1), get_output_port_descriptor(0),
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(),
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(1))->get_layout(),
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout(),
m_M_blk, m_K_blk, m_N_blk);
m_M_blk, m_K_blk, m_N_blk, m_beta);
} else {
brgemm = std::make_shared<BrgemmCPU>(new_args.at(0), new_args.at(1), new_args.at(2), m_type,
return std::make_shared<BrgemmCPU>(new_args.at(0), new_args.at(1), new_args.at(2), m_type,
get_input_port_descriptor(0), get_input_port_descriptor(1), get_input_port_descriptor(2), get_output_port_descriptor(0),
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(),
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(1))->get_layout(),
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout(),
m_M_blk, m_K_blk, m_N_blk);
m_M_blk, m_K_blk, m_N_blk, m_beta);
}
brgemm->set_beta(get_beta());
return brgemm;
}

std::shared_ptr<BrgemmCopyB> BrgemmCPU::get_brgemm_copy() const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@ class BrgemmCPU : public snippets::op::Brgemm {
BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type type,
const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_c = 0,
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0);
const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0, const float beta = 0.f);
BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, const Type type,
const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_scratch = 0, const size_t offset_c = 0,
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0);
const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0, const float beta = 0.f);
BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type type,
const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_c,
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0);
const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0, const float beta = 0.f);
BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, const Type type,
const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_scratch, const PortDescriptor& desc_c,
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0);
const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0, const float beta = 0.f);
BrgemmCPU() = default;

void validate_and_infer_types() override;
Expand Down
Loading

0 comments on commit 3d5690d

Please sign in to comment.