Skip to content

Commit

Permalink
Updated Buffer Identification logic
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed May 11, 2023
1 parent a77c392 commit db63053
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include "transformation.hpp"

#include "snippets/op/buffer.hpp"

namespace ngraph {
namespace snippets {
namespace lowered {
Expand Down Expand Up @@ -34,7 +36,7 @@ class IdentifyBuffers: public Transformation {
bool run(LinearIR& linear_ir) override;

private:
using BufferSet = std::vector<ExpressionPtr>;
using BufferSet = std::vector<std::shared_ptr<op::Buffer>>;

std::vector<bool> create_adjacency_matrix(const LinearIR& linear_ir, const BufferSet& buffers) const;
std::map<size_t, BufferSet> coloring(BufferSet& buffers, std::vector<bool>& adj);
Expand Down
134 changes: 62 additions & 72 deletions src/common/snippets/src/lowered/pass/identify_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@ namespace lowered {
namespace pass {

namespace {
auto is_intermediate_buffer(const std::shared_ptr<ov::Node>& op) -> std::shared_ptr<op::Buffer> {
const auto buffer = ov::as_type_ptr<op::Buffer>(op);
return buffer && buffer->is_intermediate_memory() ? buffer : nullptr;
}

inline size_t index(size_t col_num, size_t row, size_t col) {
return row * col_num + col;
}
Expand All @@ -34,73 +29,70 @@ std::vector<bool> IdentifyBuffers::create_adjacency_matrix(const LinearIR& linea
for (size_t i = 0; i < size; ++i)
adj[index(size, i, i)] = true;

auto update_adj_matrix = [&](const std::shared_ptr<op::Buffer>& buffer, size_t buffer_index,
const std::shared_ptr<op::Buffer>& neighbour_buffer,
size_t buffer_loop_port, size_t neighbour_buffer_loop_port,
const std::vector<int64_t>& ptr_increments,
const std::vector<int64_t>& io_data_sizes) {
if (neighbour_buffer) {
// TODO: What's about finalization offsets? It's needed?
if (ptr_increments[buffer_loop_port] != ptr_increments[neighbour_buffer_loop_port] ||
io_data_sizes[buffer_loop_port] != io_data_sizes[neighbour_buffer_loop_port]) {
const auto iter = std::find(buffers.cbegin(), buffers.cend(), linear_ir.get_expr_by_node(neighbour_buffer));
NGRAPH_CHECK(iter != buffers.cend(), "Buffer wasn't find in Buffer system of Subgraph");

const size_t adj_idx = std::distance(buffers.cbegin(), iter);
adj[index(size, adj_idx, buffer_index)] = adj[index(size, buffer_index, adj_idx)] = true;
}
// < ptr_increment, finalization_offset >
using ShiftPtrParams = std::pair<int64_t, int64_t>;

auto get_buffer_idx = [&](const std::shared_ptr<op::Buffer>& buffer) {
const auto iter = std::find(buffers.cbegin(), buffers.cend(), buffer);
NGRAPH_CHECK(iter != buffers.cend(), "Buffer wasn't find in Buffer system of Subgraph");
return std::distance(buffers.cbegin(), iter);
};

auto update_adj_matrix = [&](const std::pair<std::shared_ptr<op::Buffer>, ShiftPtrParams>& buffer,
const std::pair<std::shared_ptr<op::Buffer>, ShiftPtrParams>& neighbour_buffer) {
const bool equal_ptr_params_shifting = buffer.second == neighbour_buffer.second;
const bool equal_element_type_sizes = buffer.first->get_element_type().size() == neighbour_buffer.first->get_element_type().size();
if (!equal_ptr_params_shifting || ((buffer.second.first != 0 || buffer.second.second != 0) && !equal_element_type_sizes)) {
const auto buffer_idx = get_buffer_idx(buffer.first);
const auto neighbour_idx = get_buffer_idx(neighbour_buffer.first);
adj[index(size, neighbour_idx, buffer_idx)] = adj[index(size, buffer_idx, neighbour_idx)] = true;
}
};

for (size_t buffer_idx = 0; buffer_idx < buffers.size(); ++buffer_idx) {
// Here intermediate Buffer
const auto& buffer_expr = buffers[buffer_idx];
const auto buffer = ov::as_type_ptr<op::Buffer>(buffer_expr->get_node());
const auto& buffer_tensor = buffer_expr->get_input_tensor(0);
const auto buffer_siblings = buffer_tensor->get_consumers();
for (const auto& buffer_sibling : buffer_siblings) {
const auto& sibling_expr = buffer_sibling.get_expr();
// Skip myself
if (sibling_expr == buffer_expr) {
continue;
} else if (const auto loop_end = ov::as_type_ptr<op::LoopEnd>(sibling_expr->get_node())) {
const auto loop_tds = sibling_expr->get_input_tensors();
const auto input_count = loop_end->get_input_num();
const auto output_count = loop_end->get_output_num();
const auto& ptr_increments = loop_end->get_ptr_increments();
const auto& io_data_sizes = loop_end->get_element_type_sizes();
const auto buffer_loop_port = std::distance(loop_tds.begin(), std::find(loop_tds.begin(), loop_tds.end(), buffer_tensor));

// Verify Buffers on Loop inputs:
for (size_t input_idx = 0; input_idx < input_count; ++input_idx) {
const auto& loop_in = loop_tds[input_idx]->get_source().get_expr();
if (const auto& neighbour_buffer = is_intermediate_buffer(loop_in->get_node())) {
const auto neighbour_buffer_loop_port = input_idx;
update_adj_matrix(buffer, buffer_idx, neighbour_buffer,
buffer_loop_port, neighbour_buffer_loop_port,
ptr_increments, io_data_sizes);
}
}
for (auto expr_it = linear_ir.cbegin(); expr_it != linear_ir.cend(); expr_it++) {
const auto &expr = *expr_it;
const auto& loop_end = ov::as_type_ptr<op::LoopEnd>(expr->get_node());
if (!loop_end)
continue;

const auto input_count = loop_end->get_input_num();
const auto output_count = loop_end->get_output_num();

// Verify Buffers on Loop outputs
for (size_t output_idx = 0; output_idx < output_count; ++output_idx) {
// Skip the current Buffer
if (buffer_tensor == loop_tds[input_count + output_idx])
continue;

const auto consumer_inputs = loop_tds[input_count + output_idx]->get_consumers();
for (const auto& consumer_input : consumer_inputs) {
const auto& child_node = consumer_input.get_expr()->get_node();
if (const auto& neighbour_buffer = is_intermediate_buffer(child_node)) {
const auto neighbour_buffer_loop_port = input_count + output_idx;
update_adj_matrix(buffer, buffer_idx, neighbour_buffer,
buffer_loop_port, neighbour_buffer_loop_port,
ptr_increments, io_data_sizes);
}
}
const auto ptr_increments = loop_end->get_ptr_increments();
const auto finalization_offsets = loop_end->get_finalization_offsets();

// Buffer -> <ptr increment, finalization_offsets>
std::map<std::shared_ptr<op::Buffer>, ShiftPtrParams> buffer_neighbours;

for (size_t i = 0; i < input_count; ++i) {
const auto& parent_output = expr->get_input_tensor(i)->get_source().get_expr();
if (const auto buffer = ov::as_type_ptr<op::Buffer>(parent_output->get_node())) {
buffer_neighbours[buffer] = { ptr_increments[i], finalization_offsets[i] };
}
}
for (size_t i = 0; i < output_count; ++i) {
// The consumers of the corresponding Store ops
const auto index = input_count + i;
const auto consumer_inputs = expr->get_input_tensor(index)->get_consumers();
size_t buffer_count = 0;
size_t loop_count = 0;
for (const auto& consumer_input : consumer_inputs) {
const auto& child_node = consumer_input.get_expr()->get_node();
if (const auto buffer = ov::as_type_ptr<op::Buffer>(child_node)) {
buffer_neighbours[buffer] = { ptr_increments[index], finalization_offsets[index] };
} else if (ov::is_type<op::LoopEnd>(child_node)) {
loop_count++;
}
} else {
OPENVINO_THROW("Buffer has incorrect siblings! There can be only LoopEnds");
}
if (buffer_count > 0) {
OPENVINO_ASSERT((buffer_count == 1) && (buffer_count + loop_count == consumer_inputs.size()),
"Loop output must have not more than 1 Buffer");
}
}

for (auto buffer_it = buffer_neighbours.begin(); buffer_it != buffer_neighbours.end(); ++buffer_it) {
for (auto neighbour_it = std::next(buffer_it); neighbour_it != buffer_neighbours.end(); ++neighbour_it) {
update_adj_matrix(*buffer_it, *neighbour_it);
}
}
}
Expand Down Expand Up @@ -161,9 +153,8 @@ bool IdentifyBuffers::run(LinearIR& linear_ir) {
BufferSet buffer_exprs;

for (const auto& expr : linear_ir) {
const auto& op = expr->get_node();
if (const auto buffer = is_intermediate_buffer(op)) {
buffer_exprs.push_back(expr);
if (const auto buffer = ov::as_type_ptr<op::Buffer>(expr->get_node())) {
buffer_exprs.push_back(buffer);
}
}

Expand All @@ -176,8 +167,7 @@ bool IdentifyBuffers::run(LinearIR& linear_ir) {
for (const auto& pair : color_groups) {
const auto color = pair.first;
const auto& united_buffers = pair.second;
for (const auto& buffer_expr : united_buffers) {
const auto buffer = ov::as_type_ptr<op::Buffer>(buffer_expr->get_node());
for (const auto& buffer : united_buffers) {
buffer->set_id(color);
}
}
Expand Down

0 comments on commit db63053

Please sign in to comment.