Skip to content

Commit

Permalink
MemoryAccess update 2
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Mar 24, 2023
1 parent 5a6ba4b commit 6b041f7
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 57 deletions.
4 changes: 0 additions & 4 deletions src/common/snippets/include/snippets/op/load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ class Load : public MemoryAccess {

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

protected:
void set_output_port_descriptor(const MemoryAccess::PortDescriptor& desc, const size_t i) override;
const MemoryAccess::PortDescriptor& get_output_port_descriptor(const size_t i) const override;
};

/**
Expand Down
11 changes: 5 additions & 6 deletions src/common/snippets/include/snippets/op/memory_access.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,15 @@ class MemoryAccess : public ngraph::op::Op {
size_t get_output_port_count() const { return m_output_ports.size(); }

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;

protected:
explicit MemoryAccess(const OutputVector& arguments);
explicit MemoryAccess(const OutputVector& arguments, size_t input_count = 0, size_t output_count = 0);
MemoryAccess() = default;

virtual void set_input_port_descriptor(const PortDescriptor& desc, const size_t i);
virtual void set_output_port_descriptor(const PortDescriptor& desc, const size_t i);
virtual const PortDescriptor& get_input_port_descriptor(const size_t i) const;
virtual const PortDescriptor& get_output_port_descriptor(const size_t i) const;
void set_input_port_descriptor(const PortDescriptor& desc, const size_t i);
void set_output_port_descriptor(const PortDescriptor& desc, const size_t i);
const PortDescriptor& get_input_port_descriptor(const size_t i) const;
const PortDescriptor& get_output_port_descriptor(const size_t i) const;

std::vector<PortDescriptor> m_input_ports;
std::vector<PortDescriptor> m_output_ports;
Expand Down
4 changes: 0 additions & 4 deletions src/common/snippets/include/snippets/op/store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ class Store : public MemoryAccess {

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

protected:
void set_input_port_descriptor(const MemoryAccess::PortDescriptor& desc, const size_t i) override;
const MemoryAccess::PortDescriptor& get_input_port_descriptor(const size_t i) const override;
};

} // namespace op
Expand Down
5 changes: 2 additions & 3 deletions src/common/snippets/src/op/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@ namespace snippets {
namespace op {

Brgemm::Brgemm(const Output<Node>& A, const Output<Node>& B,
const size_t offset_a, const size_t offset_b, const size_t offset_c) : MemoryAccess({A, B}) {
const size_t offset_a, const size_t offset_b, const size_t offset_c) : MemoryAccess({A, B}, 2, 1) {
set_output_size(1);
constructor_validate_and_infer_types();
set_input_offset(offset_a, 0);
set_input_offset(offset_b, 1);
set_output_offset(offset_a, 0);
constructor_validate_and_infer_types();
}

void Brgemm::validate_and_infer_types() {
INTERNAL_OP_SCOPE(Brgemm_validate_and_infer_types);
MemoryAccess::validate_and_infer_types();
// If no leading dimensions are provided, assume dense row-major inputs-outputs
NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(),
"Brgemm currently supports only static shapes.");
Expand Down
6 changes: 3 additions & 3 deletions src/common/snippets/src/op/broadcastload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
using namespace std;
using namespace ngraph;

snippets::op::BroadcastLoad::BroadcastLoad(const Output<Node>& x, ov::PartialShape shape, size_t offset) : MemoryAccess({x}), output_shape(std::move(shape)) {
constructor_validate_and_infer_types();
snippets::op::BroadcastLoad::BroadcastLoad(const Output<Node>& x, ov::PartialShape shape, size_t offset)
: MemoryAccess({x}, 1, 0), output_shape(std::move(shape)) {
set_input_port_descriptor({1, offset}, 0);
constructor_validate_and_infer_types();
}

bool snippets::op::BroadcastLoad::visit_attributes(AttributeVisitor& visitor) {
Expand All @@ -28,6 +29,5 @@ std::shared_ptr<Node> snippets::op::BroadcastLoad::clone_with_new_inputs(const O
}

void snippets::op::BroadcastLoad::validate_and_infer_types() {
MemoryAccess::validate_and_infer_types();
set_output_type(0, get_input_element_type(0), output_shape);
}
14 changes: 3 additions & 11 deletions src/common/snippets/src/op/load.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ namespace ngraph {
namespace snippets {
namespace op {

Load::Load(const Output<Node>& x, const size_t count, const size_t offset) : MemoryAccess({x}) {
m_input_ports.resize(get_output_size());
Load::Load(const Output<Node>& x, const size_t count, const size_t offset) : MemoryAccess({x}, 1, 0) {
set_input_port_descriptor({count, offset}, 0);
constructor_validate_and_infer_types();
}
Expand All @@ -31,14 +30,6 @@ std::shared_ptr<Node> Load::clone_with_new_inputs(const OutputVector& new_args)
return std::make_shared<Load>(new_args.at(0), get_count(), get_offset());
}

void Load::set_output_port_descriptor(const MemoryAccess::PortDescriptor& desc, const size_t i) {
throw ov::Exception("Load node doesn't have memory access output port");
}

const MemoryAccess::PortDescriptor& Load::get_output_port_descriptor(const size_t i) const {
throw ov::Exception("Load node doesn't have memory access output port");
}

LoadReshape::LoadReshape(const Output<ov::Node>& x, const size_t count, const size_t offset, std::vector<size_t> order)
: Load(x, count, offset), m_order(std::move(order)) {
const auto& in_shape = x.get_partial_shape();
Expand All @@ -49,11 +40,12 @@ LoadReshape::LoadReshape(const Output<ov::Node>& x, const size_t count, const si
*std::min_element(m_order.begin(), m_order.end()) == 0, "LoadReshape detected invalid values in new_order");
const std::set<size_t> unique_dims(order.begin(), order.end());
NGRAPH_CHECK(unique_dims.size() == order.size(), "LoadReshape order must not contain repeated elements");
m_input_ports.resize(get_input_size());
set_input_port_descriptor({count, offset}, 0);
constructor_validate_and_infer_types();
}

void snippets::op::LoadReshape::validate_and_infer_types() {
MemoryAccess::validate_and_infer_types();
const auto& old_shape = get_input_partial_shape(0);
ov::PartialShape new_shape;
for (const auto idx : m_order)
Expand Down
9 changes: 1 addition & 8 deletions src/common/snippets/src/op/memory_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,13 @@ namespace ngraph {
namespace snippets {
namespace op {

MemoryAccess::MemoryAccess(const OutputVector& arguments) : Op(arguments) {}

void MemoryAccess::validate_and_infer_types() {
// We create descriptors in validate_and_infer_types() (instead of in ctor)
const auto input_count = get_input_size();
const auto output_count = get_output_size();
MemoryAccess::MemoryAccess(const OutputVector& arguments, size_t input_count, size_t output_count) : Op(arguments) {
while (m_input_ports.size() < input_count) {
m_input_ports.push_back({0, 0, m_input_ports.size()});
}
while (m_output_ports.size() < output_count) {
m_output_ports.push_back({0, 0, m_output_ports.size()});
}
OPENVINO_ASSERT(m_input_ports.size() == input_count, "The count of input ports must be equal to input count");
OPENVINO_ASSERT(m_output_ports.size() == output_count, "The count of output ports must be equal to output count");
}

bool MemoryAccess::visit_attributes(AttributeVisitor& visitor) {
Expand Down
11 changes: 1 addition & 10 deletions src/common/snippets/src/op/store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ namespace ngraph {
namespace snippets {
namespace op {

snippets::op::Store::Store(const Output<Node>& x, const size_t count, const size_t offset) : MemoryAccess({x}) {
m_output_ports.resize(get_output_size());
snippets::op::Store::Store(const Output<Node>& x, const size_t count, const size_t offset) : MemoryAccess({x}, 0, 1) {
set_output_port_descriptor({count, offset}, 0);
constructor_validate_and_infer_types();
}
Expand All @@ -31,14 +30,6 @@ std::shared_ptr<Node> snippets::op::Store::clone_with_new_inputs(const OutputVec
return std::make_shared<Store>(new_args.at(0), get_count(), get_offset());
}

void Store::set_input_port_descriptor(const MemoryAccess::PortDescriptor& desc, const size_t i) {
throw ov::Exception("Store node doesn't have memory access input port");
}

const MemoryAccess::PortDescriptor& Store::get_input_port_descriptor(const size_t i) const {
throw ov::Exception("Store node doesn't have memory access input port");
}

} // namespace op
} // namespace snippets
} // namespace ngraph
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,27 @@ using namespace ov;

intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output<Node>& x, const element::Type src_type, const Type type,
const size_t offset_in, const size_t offset_out0, const size_t offset_out1)
: ngraph::snippets::op::MemoryAccess({x}), m_type(type), m_src_type(src_type) {
set_output_size(is_with_compensations() ? 2 : 1);
constructor_validate_and_infer_types();
: ngraph::snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1), m_type(type), m_src_type(src_type) {
set_output_size(get_output_port_count());
m_input_ports.resize(get_input_size());
m_output_ports.resize(get_output_size());
set_input_port_descriptor({0, offset_in}, 0);
set_output_port_descriptor({0, offset_out0}, 0);
if (is_with_compensations()) {
set_output_port_descriptor({0, offset_out1}, 1);
}
constructor_validate_and_infer_types();
}

bool intel_cpu::BrgemmCopyB::visit_attributes(AttributeVisitor& visitor) {
INTERNAL_OP_SCOPE(BrgemmRepack_visit_attributes);
MemoryAccess::visit_attributes(visitor);
visit_attributes(visitor);
visitor.on_attribute("src_type", m_src_type);
return true;
}

void intel_cpu::BrgemmCopyB::validate_and_infer_types() {
INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types);
MemoryAccess::validate_and_infer_types();

const auto element_type = get_input_element_type(0);
NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,30 @@ BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type ty
// We call default ctor of Brgemm class to avoid incorrect shape infer in constructor_validate_and_type_infer() call
set_arguments({A, B});
set_output_size(1);
constructor_validate_and_infer_types();
m_input_ports.resize(get_input_size());
m_output_ports.resize(get_output_size());
set_input_port_descriptor({0, offset_a}, 0);
set_input_port_descriptor({0, offset_b}, 1);
set_output_port_descriptor({0, offset_c}, 0);
constructor_validate_and_infer_types();
}

BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, const Type type,
const size_t offset_a, const size_t offset_b, const size_t offset_scratch, const size_t offset_c)
: Brgemm(), m_type(type) {
set_arguments({A, B, scratch});
set_output_size(1);
constructor_validate_and_infer_types();
m_input_ports.resize(get_input_size());
m_output_ports.resize(get_output_size());
set_input_port_descriptor({0, offset_a}, 0);
set_input_port_descriptor({0, offset_b}, 1);
set_output_port_descriptor({0, offset_c}, 0);
set_input_port_descriptor({0, offset_scratch}, 2);
constructor_validate_and_infer_types();
}

void BrgemmCPU::validate_and_infer_types() {
INTERNAL_OP_SCOPE(BrgemmCPU_validate_and_infer_types);
MemoryAccess::validate_and_infer_types();
// If no leading dimensions are provided, assume dense row-major inputs-outputs
NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(),
"BrgemmCPU currently supports only static shapes.");
Expand Down

0 comments on commit 6b041f7

Please sign in to comment.