Skip to content

Commit

Permalink
[Snippets] Deep copy interface (openvinotoolkit#20242)
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanNovoselov authored Nov 15, 2023
1 parent d304be0 commit 9e7deba
Show file tree
Hide file tree
Showing 11 changed files with 169 additions and 72 deletions.
10 changes: 9 additions & 1 deletion src/common/snippets/include/snippets/lowered/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace lowered {

class LinearIR;
using ExpressionPtr = std::shared_ptr<Expression>;
using ExressionMap = std::unordered_map<Expression*, ExpressionPtr>;
class Expression : public std::enable_shared_from_this<Expression> {
friend class LinearIR;
friend class ExpressionPort;
Expand Down Expand Up @@ -55,11 +56,16 @@ class Expression : public std::enable_shared_from_this<Expression> {

std::vector<size_t> get_loop_ids() const;
void set_loop_ids(const std::vector<size_t>& loops);
virtual ExpressionPtr clone_with_new_inputs(const std::vector<PortConnectorPtr>& new_inputs,
const std::shared_ptr<Node>& new_node) const;
ExpressionPtr clone_with_new_inputs(const ExressionMap& expr_map, const std::shared_ptr<Node>& new_node) const;

protected:
Expression(const Expression& other);
// Note: The constructor initialization is private since an expression can be created only by Linear IR.
// The method must be used only by Linear IR builder of expressions!
Expression(const std::shared_ptr<Node>& n, const std::shared_ptr<IShapeInferSnippetsFactory>& factory);
void update_node_and_connectors(const std::vector<PortConnectorPtr>& new_inputs, const std::shared_ptr<Node>& new_node);

void replace_input(size_t port, PortConnectorPtr to);

Expand All @@ -80,12 +86,14 @@ class IOExpression : public Expression {

public:
enum class io_type {INPUT, OUTPUT, UNDEFINED};

ExpressionPtr clone_with_new_inputs(const std::vector<PortConnectorPtr>& new_inputs,
const std::shared_ptr<Node>& new_node) const override;
int64_t get_index() const { return m_index; }
io_type get_type() const { return m_type; }
// Result needs shapeInfer to copy shape from Parent's output to this expr input
bool needShapeInfer() const override {return m_type == io_type::OUTPUT; }
private:
IOExpression(const IOExpression& other) = default;
explicit IOExpression(const std::shared_ptr<ov::opset1::Parameter>& n, int64_t index, const std::shared_ptr<IShapeInferSnippetsFactory>& factory);
explicit IOExpression(const std::shared_ptr<ov::opset1::Result>& n, int64_t index, const std::shared_ptr<IShapeInferSnippetsFactory>& factory);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ class ExpressionPort {

ExpressionPort() = default;
explicit ExpressionPort(const std::shared_ptr<Expression>& expr, Type type, size_t port);
/**
* @interface clone_with_new_expr
* @brief Creates similar Expression port, but for new expression
*/
std::shared_ptr<ExpressionPort> clone_with_new_expr(const std::shared_ptr<Expression>& new_expr) const;

std::shared_ptr<Expression> get_expr() const;
Type get_type() const { return m_type; }
Expand Down
6 changes: 4 additions & 2 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ class LinearIR {

ExpressionPtr create_expression(const std::shared_ptr<Node>& n, const std::vector<PortConnectorPtr>& inputs);

static LinearIR::container deep_copy_range(LinearIR::container::const_iterator begin, LinearIR::container::const_iterator end);
std::shared_ptr<LinearIR> clone() const;
static LinearIR::container deep_copy_range(LinearIR::container::const_iterator begin,
LinearIR::container::const_iterator end,
ExressionMap& expression_map);

const container& get_ops() const {return m_expressions; }
const io_container& get_IO_ops() const {return m_io_expressions; }
Expand Down Expand Up @@ -116,7 +119,6 @@ class LinearIR {
IShapeInferSnippets::Result shape_infer(const std::vector<VectorDimsRef>& input_shapes);
const std::shared_ptr<ShapeInferSnippetsNode>& get_shape_infer_instance() const {return m_shape_infer; }
VectorDims get_master_shape() const;
LinearIR deep_copy() const;

private:
std::shared_ptr<ShapeInferSnippetsNode> m_shape_infer = nullptr;
Expand Down
5 changes: 5 additions & 0 deletions src/common/snippets/include/snippets/lowered/loop_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class LinearIR::LoopManager {
LoopPort(const ExpressionPort& port, bool is_scheduled = true)
: expr_port(std::make_shared<ExpressionPort>(port)), is_incremented(is_scheduled) {}

std::shared_ptr<LoopPort> clone_with_new_expr(const ExpressionPtr& new_expr) const;

friend bool operator==(const LoopPort& lhs, const LoopPort& rhs);
friend bool operator!=(const LoopPort& lhs, const LoopPort& rhs);
friend bool operator<(const LoopPort& lhs, const LoopPort& rhs);
Expand All @@ -49,6 +51,8 @@ class LinearIR::LoopManager {
const std::vector<ExpressionPort>& entries,
const std::vector<ExpressionPort>& exits);

std::shared_ptr<LoopInfo> clone_with_new_expr(const ExressionMap& expr_map) const;

size_t work_amount = 0;
size_t increment = 0;
size_t dim_idx = 0; // The numeration begins from the end (dim_idx = 0 -> is the most inner dimension)
Expand All @@ -63,6 +67,7 @@ class LinearIR::LoopManager {
};
using LoopInfoPtr = std::shared_ptr<LoopInfo>;

std::shared_ptr<LoopManager> clone_with_new_expr(const ExressionMap& expr_map) const;
size_t add_loop_info(const LoopInfoPtr& loop);
void remove_loop_info(size_t index);
LoopInfoPtr get_loop_info(size_t index) const;
Expand Down
67 changes: 67 additions & 0 deletions src/common/snippets/src/lowered/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ Expression::Expression(const std::shared_ptr<Node>& n, const std::shared_ptr<ISh
}
}

// Copy constructor: shares the source node, emitter, loop ids and shape-inference handle,
// and deep-copies the port descriptors. Connectors are deliberately NOT copied: building them
// requires shared_from_this(), which is unavailable inside a constructor, and a cloned
// expression is rarely expected to reuse the original's connectors anyway.
Expression::Expression(const Expression& other) :
    std::enable_shared_from_this<Expression>(other), m_source_node(other.m_source_node),
    m_emitter(other.m_emitter), m_loop_ids(other.m_loop_ids), m_shapeInference(other.m_shapeInference) {
    auto deep_copy_descriptors = [](const std::vector<PortDescriptorPtr>& src) {
        std::vector<PortDescriptorPtr> copies;
        copies.reserve(src.size());
        for (const auto& desc : src)
            copies.push_back(desc->clone());
        return copies;
    };
    m_input_port_descriptors = deep_copy_descriptors(other.m_input_port_descriptors);
    m_output_port_descriptors = deep_copy_descriptors(other.m_output_port_descriptors);
    m_input_port_connectors.clear();
    m_output_port_connectors.clear();
}

const PortConnectorPtr& Expression::get_input_port_connector(size_t i) const {
OPENVINO_ASSERT(i < m_input_port_connectors.size(), "Failed to get input port connector: target input port must be less than input count!");
return m_input_port_connectors[i];
Expand Down Expand Up @@ -103,6 +119,50 @@ void Expression::set_loop_ids(const std::vector<size_t>& loops) {
m_loop_ids = loops;
}

// Rebinds this expression to a new source node and a new set of input connectors, then
// re-creates the output connectors. Intended to be called right after copy-construction
// (see clone_with_new_inputs), since the copy constructor leaves all connectors empty.
// Throws (OPENVINO_ASSERT) if the new node's type differs from the current one or if the
// number of passed inputs doesn't match the number of input port descriptors.
void Expression::update_node_and_connectors(const std::vector<PortConnectorPtr>& new_inputs,
                                            const std::shared_ptr<Node>& new_node) {
    OPENVINO_ASSERT(m_source_node->get_type_info() == new_node->get_type_info(),
                    "Can't clone expression for a new node with incompatible type");
    m_source_node = new_node;
    OPENVINO_ASSERT(new_inputs.size() == m_input_port_descriptors.size(),
                    "Can't create Expression with new inputs: invalid number of input port connectors passed");
    m_input_port_connectors = new_inputs;
    for (size_t i = 0; i < m_input_port_descriptors.size(); i++) {
        const auto& i_con = new_inputs[i];
        const auto& i_port = get_input_port(i);
        // Register this expression as a consumer of the shared connector (skip if already registered)
        if (!i_con->found_consumer(i_port))
            i_con->add_consumer(i_port);
    }
    // Output connectors are always created fresh, so they point to this expression's own ports
    m_output_port_connectors.resize(m_output_port_descriptors.size());
    for (size_t i = 0; i < m_output_port_descriptors.size(); i++) {
        m_output_port_connectors[i] = std::make_shared<PortConnector>(get_output_port(i));
    }
}

// Creates a copy of this expression bound to new_node and wired to the given input connectors.
ExpressionPtr Expression::clone_with_new_inputs(const std::vector<PortConnectorPtr>& new_inputs,
                                                const std::shared_ptr<Node>& new_node) const {
    // Note: the copy constructor is protected, so make_shared can't be used here
    std::shared_ptr<Expression> cloned(new Expression(*this));
    cloned->update_node_and_connectors(new_inputs, new_node);
    return cloned;
}

// Creates a copy of this expression bound to new_node, remapping the input connectors through
// expr_map (original expression pointer -> cloned expression). If an input's producer was cloned,
// the corresponding output connector of the clone is used; otherwise the original connector is kept.
ExpressionPtr Expression::clone_with_new_inputs(const ExressionMap& expr_map,
                                                const std::shared_ptr<Node>& new_node) const {
    std::vector<PortConnectorPtr> remapped_inputs;
    remapped_inputs.reserve(m_input_port_connectors.size());
    for (const auto& connector : m_input_port_connectors) {
        const auto& source = connector->get_source();
        const auto mapped = expr_map.find(source.get_expr().get());
        remapped_inputs.emplace_back(mapped != expr_map.end()
                                         ? mapped->second->get_output_port_connector(source.get_index())
                                         : connector);
    }
    return clone_with_new_inputs(remapped_inputs, new_node);
}

ExpressionPort Expression::get_input_port(size_t i) {
return ExpressionPort(this->shared_from_this(), ExpressionPort::Type::Input, i);
}
Expand Down Expand Up @@ -146,6 +206,13 @@ IOExpression::IOExpression(const std::shared_ptr<ov::opset1::Parameter>& par, in
IOExpression::IOExpression(const std::shared_ptr<ov::opset1::Result>& res, int64_t index, const std::shared_ptr<IShapeInferSnippetsFactory>& factory)
: Expression(res, factory), m_index(index), m_type{io_type::OUTPUT} {}

// IOExpression counterpart of Expression::clone_with_new_inputs: the IOExpression copy
// constructor additionally preserves the i/o type and index of the original.
ExpressionPtr IOExpression::clone_with_new_inputs(const std::vector<PortConnectorPtr>& new_inputs,
                                                  const std::shared_ptr<Node>& new_node) const {
    // Note: the copy constructor is private, so make_shared can't be used here
    std::shared_ptr<IOExpression> cloned(new IOExpression(*this));
    cloned->update_node_and_connectors(new_inputs, new_node);
    return cloned;
}

}// namespace lowered
}// namespace snippets
}// namespace ov
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/expression_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ ExpressionPtr LinearIR::ExpressionFactory::create(const std::shared_ptr<op::Loop
const std::vector<PortConnectorPtr>& inputs,
const LinearIR& linear_ir) {
OPENVINO_ASSERT(inputs.empty(), "LoopBegin cannot have inputs");
auto expr = std::make_shared<Expression>(Expression(n, linear_ir.m_shape_infer_factory));
auto expr = std::shared_ptr<Expression>(new Expression(n, linear_ir.m_shape_infer_factory));
init_expression_inputs(expr, inputs);
create_expression_outputs(expr);
expr->validate();
Expand Down
4 changes: 4 additions & 0 deletions src/common/snippets/src/lowered/expression_port.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ namespace lowered {
ExpressionPort::ExpressionPort(const std::shared_ptr<Expression>& expr, Type type, size_t port)
: m_expr(expr), m_type(type), m_port_index(port) {}

// Creates an equivalent port (same type and port index) that refers to new_expr instead of
// the original expression. Used when re-targeting ports to cloned expressions.
std::shared_ptr<ExpressionPort> ExpressionPort::clone_with_new_expr(const std::shared_ptr<Expression>& new_expr) const {
    return std::make_shared<ExpressionPort>(new_expr, m_type, m_port_index);
}

std::shared_ptr<Expression> ExpressionPort::get_expr() const {
const auto expr_ptr = m_expr.lock();
OPENVINO_ASSERT(expr_ptr != nullptr, "ExpressionPort has invalid expression pointer");
Expand Down
93 changes: 30 additions & 63 deletions src/common/snippets/src/lowered/linear_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ LinearIR::LinearIR(const std::shared_ptr<ov::Model>& model, const std::shared_pt
m_shape_infer = std::make_shared<LIRShapeInfer>(m_expressions, m_io_expressions);
}

// Returns a deep copy of this LinearIR: expressions are deep-copied, the loop manager and the
// shape-inference instance are re-targeted to the cloned expressions, and the lookup maps are
// rebuilt. The shape-inference factory is shared, since it doesn't depend on the LIR.
std::shared_ptr<LinearIR> LinearIR::clone() const {
    auto cloned = std::make_shared<LinearIR>();
    cloned->m_config = m_config;

    // expression_map records the original -> clone correspondence produced by deep_copy_range
    ExressionMap expression_map;
    cloned->m_expressions = deep_copy_range(m_expressions.cbegin(), m_expressions.cend(), expression_map);
    // Rebuild the node lookup map and the IO expression list from the cloned expressions
    for (const auto& expr : cloned->m_expressions) {
        cloned->m_node2expression_map[expr->get_node()] = expr;
        if (const auto& io = std::dynamic_pointer_cast<IOExpression>(expr))
            cloned->m_io_expressions.push_back(io);
    }

    // The loop manager stores expression ports, so it must be re-targeted to the clones
    cloned->m_loop_manager = m_loop_manager->clone_with_new_expr(expression_map);
    // It's Ok to share shapeInfer factory ptr, since the factory doesn't depend on LIR in any way
    cloned->m_shape_infer_factory = m_shape_infer_factory;
    // shapeInfer stores expression pointers, so it is re-created over the cloned expressions
    cloned->m_shape_infer = std::make_shared<LIRShapeInfer>(cloned->m_expressions, cloned->m_io_expressions);
    return cloned;
}

ExpressionPtr LinearIR::create_expression(const std::shared_ptr<Node>& n, const std::shared_ptr<ov::Model>& model) {
return ExpressionFactory::build(n, *this, model);
}
Expand Down Expand Up @@ -99,80 +118,28 @@ void LinearIR::serialize(const std::string& xml, const std::string& bin) const {
ov::pass::Serialize(xml, bin).run_on_model(tmp_model);
}

LinearIR::container LinearIR::deep_copy_range(LinearIR::container::const_iterator begin, LinearIR::container::const_iterator end) {
auto deep_clone_ports = [](std::vector<PortDescriptorPtr>& ports) {
for (auto& port : ports) { port = port->clone(); }
};
LinearIR::container LinearIR::deep_copy_range(LinearIR::container::const_iterator begin,
LinearIR::container::const_iterator end,
ExressionMap& expression_map) {
OPENVINO_ASSERT(expression_map.empty(), "deep_copy_range expects empty expression_map as an input");
LinearIR::container result;
NodeVector original_nodes;
for (auto it = begin; it != end; it++)
original_nodes.push_back((*it)->get_node());
ngraph::NodeMap node_map;
OPENVINO_SUPPRESS_DEPRECATED_START
ngraph::clone_nodes(original_nodes, node_map);
OPENVINO_SUPPRESS_DEPRECATED_END
for (auto it = begin; it != end; it++) {
// copy by value, so result shared_pointer point to new objects
Expression new_expr = **it;
new_expr.m_source_node = node_map[(*it)->get_node().get()];
deep_clone_ports(new_expr.m_input_port_descriptors);
deep_clone_ports(new_expr.m_output_port_descriptors);
result.emplace_back(std::make_shared<Expression>(new_expr));
}
return result;
}

LinearIR LinearIR::deep_copy() const {
// todo: implement the same functionality using standard copy constructor
auto clone_ports_descriptors = [](std::vector<PortDescriptorPtr>& ports) {
std::for_each(ports.begin(), ports.end(), [](PortDescriptorPtr& pd) { pd = pd->clone(); });
};
const auto& original_lir = *this;
LinearIR new_lir;
new_lir.m_config = original_lir.m_config;
new_lir.m_shape_infer = original_lir.m_shape_infer;
NodeVector original_nodes;
original_nodes.reserve(original_lir.m_expressions.size());
std::unordered_map<PortConnectorPtr, PortConnectorPtr> connectors_map;
for (const auto& orig_expr : original_lir) {
original_nodes.push_back(orig_expr->get_node());
const auto& copy_expr = ExpressionFactory::shallow_copy(orig_expr);
clone_ports_descriptors(copy_expr->m_input_port_descriptors);
clone_ports_descriptors(copy_expr->m_output_port_descriptors);

for (auto& orig_con : copy_expr->m_output_port_connectors) {
const auto& copy_source = copy_expr->get_output_port(orig_con->get_source().get_index());
const auto& copy_con = std::make_shared<PortConnector>(copy_source);
connectors_map[orig_con] = copy_con;
orig_con = copy_con;
}
for (size_t i = 0; i < copy_expr->get_input_count(); i++) {
const auto& copy_connector = connectors_map[copy_expr->get_input_port_connector(i)];
const auto& copy_consumer = copy_expr->get_input_port(i);
copy_connector->add_consumer(copy_consumer);
copy_expr->replace_input(i, copy_connector);
}

if (auto io_expr = std::dynamic_pointer_cast<IOExpression>(copy_expr))
new_lir.m_io_expressions.push_back(io_expr);
new_lir.m_expressions.push_back(copy_expr);
}
// node_map and expr_map map original node pointer (expression) to a new pointer (expression)
ngraph::NodeMap node_map;
OPENVINO_SUPPRESS_DEPRECATED_START
ngraph::clone_nodes(original_nodes, node_map);
OPENVINO_SUPPRESS_DEPRECATED_END
new_lir.m_node2expression_map.clear();
for (const auto& copy_expr : new_lir.m_expressions) {
copy_expr->m_source_node = node_map[copy_expr->m_source_node.get()];
new_lir.m_node2expression_map[copy_expr->m_source_node] = copy_expr;

for (auto it = begin; it != end; it++) {
const auto& expr = *it;
const auto& new_expr = expr->clone_with_new_inputs(expression_map, node_map[expr->get_node().get()]);
result.push_back(new_expr);
expression_map[expr.get()] = new_expr;
}
new_lir.m_loop_manager = std::make_shared<LoopManager>();
// It's Ok to share shapeInfer factory, since LIR doesn't change it
new_lir.m_shape_infer_factory = m_shape_infer_factory;
// Note: shapeInfer stores expression pointers. we re-create it, so shape inference is performed on cloned exprs.
new_lir.m_shape_infer = std::make_shared<LIRShapeInfer>(new_lir.m_expressions, new_lir.m_io_expressions);
return new_lir;
return result;
}

void LinearIR::debug_print(bool tds_as_pointers) const {
Expand Down
42 changes: 40 additions & 2 deletions src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,18 @@
namespace ov {
namespace snippets {
namespace lowered {
using LoopManager = LinearIR::LoopManager;
using LoopPort = LoopManager::LoopPort;
using LoopInfo = LoopManager::LoopInfo;

// Copies this loop port (preserving its flags, e.g. is_incremented) and re-targets its
// expression port to new_expr.
std::shared_ptr<LoopPort> LoopPort::clone_with_new_expr(const ExpressionPtr& new_expr) const {
    auto new_loop_port = std::make_shared<LoopPort>(*this);
    new_loop_port->expr_port = expr_port->clone_with_new_expr(new_expr);
    return new_loop_port;
}

LinearIR::LoopManager::LoopInfo::LoopInfo(size_t work_amount, size_t increment, size_t dim_idx,
const std::vector<ExpressionPort>& entries, const std::vector<ExpressionPort>& exits)
LoopInfo::LoopInfo(size_t work_amount, size_t increment, size_t dim_idx,
const std::vector<ExpressionPort>& entries, const std::vector<ExpressionPort>& exits)
: work_amount(work_amount), increment(increment), dim_idx(dim_idx), outer_splited_loop(false) {
entry_points.reserve(entries.size());
exit_points.reserve(exits.size());
Expand All @@ -27,6 +36,35 @@ LinearIR::LoopManager::LoopInfo::LoopInfo(size_t work_amount, size_t increment,
exit_points.emplace_back(port);
}

std::shared_ptr<LoopInfo> LoopInfo::clone_with_new_expr(const ExressionMap& expr_map) const {
auto clone_loop_ports = [&expr_map](const std::vector<LoopPort>& port_points) {
std::vector<LoopPort> cloned_port_points;
cloned_port_points.reserve(port_points.size());
for (const auto& p : port_points) {
const auto& expr = p.expr_port->get_expr().get();
OPENVINO_ASSERT(expr_map.count(expr), "Can't clone LoopInfo: old expression is not in the map");
const auto& new_expr = expr_map.at(expr);
cloned_port_points.emplace_back(*p.clone_with_new_expr(new_expr));
}
return cloned_port_points;
};
const auto& new_entry_points = clone_loop_ports(entry_points);
const auto& new_exit_points = clone_loop_ports(exit_points);

auto new_loop_info = std::make_shared<LoopInfo>(work_amount, increment, dim_idx, new_entry_points, new_exit_points);
new_loop_info->outer_splited_loop = outer_splited_loop;

return new_loop_info;
}

// Creates a copy of this LoopManager in which every stored LoopInfo is re-targeted to the
// cloned expressions from expr_map; loop ids and the id counter are preserved.
std::shared_ptr<LoopManager> LoopManager::clone_with_new_expr(const ExressionMap& expr_map) const {
    auto cloned = std::make_shared<LoopManager>();
    for (const auto& kv : m_map)
        cloned->m_map.emplace(kv.first, kv.second->clone_with_new_expr(expr_map));
    cloned->next_id = next_id;
    return cloned;
}

bool operator==(const LinearIR::LoopManager::LoopPort& lhs, const LinearIR::LoopManager::LoopPort& rhs) {
if (&lhs == &rhs)
return true;
Expand Down
Loading

0 comments on commit 9e7deba

Please sign in to comment.