diff --git a/src/common/snippets/include/snippets/op/load.hpp b/src/common/snippets/include/snippets/op/load.hpp index 17d052852228ee..38acd0e8a10255 100644 --- a/src/common/snippets/include/snippets/op/load.hpp +++ b/src/common/snippets/include/snippets/op/load.hpp @@ -33,10 +33,6 @@ class Load : public MemoryAccess { void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; - -protected: - void set_output_port_descriptor(const MemoryAccess::PortDescriptor& desc, const size_t i) override; - const MemoryAccess::PortDescriptor& get_output_port_descriptor(const size_t i) const override; }; /** diff --git a/src/common/snippets/include/snippets/op/memory_access.hpp b/src/common/snippets/include/snippets/op/memory_access.hpp index e02bb5635e5eb2..7b090c8f65d528 100644 --- a/src/common/snippets/include/snippets/op/memory_access.hpp +++ b/src/common/snippets/include/snippets/op/memory_access.hpp @@ -59,16 +59,15 @@ class MemoryAccess : public ngraph::op::Op { size_t get_output_port_count() const { return m_output_ports.size(); } bool visit_attributes(AttributeVisitor& visitor) override; - void validate_and_infer_types() override; protected: - explicit MemoryAccess(const OutputVector& arguments); + explicit MemoryAccess(const OutputVector& arguments, size_t input_count = 0, size_t output_count = 0); MemoryAccess() = default; - virtual void set_input_port_descriptor(const PortDescriptor& desc, const size_t i); - virtual void set_output_port_descriptor(const PortDescriptor& desc, const size_t i); - virtual const PortDescriptor& get_input_port_descriptor(const size_t i) const; - virtual const PortDescriptor& get_output_port_descriptor(const size_t i) const; + void set_input_port_descriptor(const PortDescriptor& desc, const size_t i); + void set_output_port_descriptor(const PortDescriptor& desc, const size_t i); + const PortDescriptor& get_input_port_descriptor(const size_t i) const; + const PortDescriptor& get_output_port_descriptor(const size_t i) const; std::vector m_input_ports; std::vector m_output_ports; diff --git a/src/common/snippets/include/snippets/op/store.hpp b/src/common/snippets/include/snippets/op/store.hpp index 7d0883f6dbaac7..b62a4c6ccb18b7 100644 --- a/src/common/snippets/include/snippets/op/store.hpp +++ b/src/common/snippets/include/snippets/op/store.hpp @@ -33,10 +33,6 @@ class Store : public MemoryAccess { void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; - -protected: - void set_input_port_descriptor(const MemoryAccess::PortDescriptor& desc, const size_t i) override; - const MemoryAccess::PortDescriptor& get_input_port_descriptor(const size_t i) const override; }; } // namespace op diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp index 800bd66d14bcfc..743653099b8601 100644 --- a/src/common/snippets/src/op/brgemm.cpp +++ b/src/common/snippets/src/op/brgemm.cpp @@ -13,17 +13,16 @@ namespace snippets { namespace op { Brgemm::Brgemm(const Output& A, const Output& B, - const size_t offset_a, const size_t offset_b, const size_t offset_c) : MemoryAccess({A, B}) { + const size_t offset_a, const size_t offset_b, const size_t offset_c) : MemoryAccess({A, B}, 2, 1) { set_output_size(1); - constructor_validate_and_infer_types(); set_input_offset(offset_a, 0); set_input_offset(offset_b, 1); set_output_offset(offset_a, 0); + constructor_validate_and_infer_types(); } void Brgemm::validate_and_infer_types() { INTERNAL_OP_SCOPE(Brgemm_validate_and_infer_types); - MemoryAccess::validate_and_infer_types(); // If no leading dimensions are provided, assume dense row-major inputs-outputs NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(), "Brgemm currently supports only static shapes."); diff --git a/src/common/snippets/src/op/broadcastload.cpp b/src/common/snippets/src/op/broadcastload.cpp index f24ff3fc46a000..ccbb5f9b9af9a7 100644 --- a/src/common/snippets/src/op/broadcastload.cpp +++ b/src/common/snippets/src/op/broadcastload.cpp @@ -11,9 +11,10 @@ using namespace std; using namespace ngraph; -snippets::op::BroadcastLoad::BroadcastLoad(const Output& x, ov::PartialShape shape, size_t offset) : MemoryAccess({x}), output_shape(std::move(shape)) { - constructor_validate_and_infer_types(); +snippets::op::BroadcastLoad::BroadcastLoad(const Output& x, ov::PartialShape shape, size_t offset) + : MemoryAccess({x}, 1, 0), output_shape(std::move(shape)) { set_input_port_descriptor({1, offset}, 0); + constructor_validate_and_infer_types(); } bool snippets::op::BroadcastLoad::visit_attributes(AttributeVisitor& visitor) { @@ -28,6 +29,5 @@ std::shared_ptr snippets::op::BroadcastLoad::clone_with_new_inputs(const O } void snippets::op::BroadcastLoad::validate_and_infer_types() { - MemoryAccess::validate_and_infer_types(); set_output_type(0, get_input_element_type(0), output_shape); } diff --git a/src/common/snippets/src/op/load.cpp b/src/common/snippets/src/op/load.cpp index e135c95149863c..f1f5bc42c7a3da 100644 --- a/src/common/snippets/src/op/load.cpp +++ b/src/common/snippets/src/op/load.cpp @@ -12,8 +12,7 @@ namespace ngraph { namespace snippets { namespace op { -Load::Load(const Output& x, const size_t count, const size_t offset) : MemoryAccess({x}) { - m_input_ports.resize(get_output_size()); +Load::Load(const Output& x, const size_t count, const size_t offset) : MemoryAccess({x}, 1, 0) { set_input_port_descriptor({count, offset}, 0); constructor_validate_and_infer_types(); } @@ -31,14 +30,6 @@ std::shared_ptr Load::clone_with_new_inputs(const OutputVector& new_args) return std::make_shared(new_args.at(0), get_count(), get_offset()); } -void Load::set_output_port_descriptor(const MemoryAccess::PortDescriptor& desc, const size_t i) { - throw ov::Exception("Load node doesn't have memory access output port"); -} - -const MemoryAccess::PortDescriptor& Load::get_output_port_descriptor(const size_t i) const { - throw ov::Exception("Load node doesn't have memory access output port"); -} - LoadReshape::LoadReshape(const Output& x, const size_t count, const size_t offset, std::vector order) : Load(x, count, offset), m_order(std::move(order)) { const auto& in_shape = x.get_partial_shape(); @@ -49,11 +40,12 @@ LoadReshape::LoadReshape(const Output& x, const size_t count, const si *std::min_element(m_order.begin(), m_order.end()) == 0, "LoadReshape detected invalid values in new_order"); const std::set unique_dims(order.begin(), order.end()); NGRAPH_CHECK(unique_dims.size() == order.size(), "LoadReshape order must not contain repeated elements"); + m_input_ports.resize(get_input_size()); + set_input_port_descriptor({count, offset}, 0); constructor_validate_and_infer_types(); } void snippets::op::LoadReshape::validate_and_infer_types() { - MemoryAccess::validate_and_infer_types(); const auto& old_shape = get_input_partial_shape(0); ov::PartialShape new_shape; for (const auto idx : m_order) diff --git a/src/common/snippets/src/op/memory_access.cpp b/src/common/snippets/src/op/memory_access.cpp index 734e2cedc08b10..ea0e4649f9e5de 100644 --- a/src/common/snippets/src/op/memory_access.cpp +++ b/src/common/snippets/src/op/memory_access.cpp @@ -9,20 +9,13 @@ namespace ngraph { namespace snippets { namespace op { -MemoryAccess::MemoryAccess(const OutputVector& arguments) : Op(arguments) {} - -void MemoryAccess::validate_and_infer_types() { - // We create descriptors in validate_and_infer_types() (instead of in ctor) - const auto input_count = get_input_size(); - const auto output_count = get_output_size(); +MemoryAccess::MemoryAccess(const OutputVector& arguments, size_t input_count, size_t output_count) : Op(arguments) { while (m_input_ports.size() < input_count) { m_input_ports.push_back({0, 0, m_input_ports.size()}); } while (m_output_ports.size() < output_count) { m_output_ports.push_back({0, 0, m_output_ports.size()}); } - OPENVINO_ASSERT(m_input_ports.size() == input_count, "The count of input ports must be equal to input count"); - OPENVINO_ASSERT(m_output_ports.size() == output_count, "The count of output ports must be equal to output count"); } bool MemoryAccess::visit_attributes(AttributeVisitor& visitor) { diff --git a/src/common/snippets/src/op/store.cpp b/src/common/snippets/src/op/store.cpp index f9a013dedef22e..8ac2c4cdf1704e 100644 --- a/src/common/snippets/src/op/store.cpp +++ b/src/common/snippets/src/op/store.cpp @@ -12,8 +12,7 @@ namespace ngraph { namespace snippets { namespace op { -snippets::op::Store::Store(const Output& x, const size_t count, const size_t offset) : MemoryAccess({x}) { - m_output_ports.resize(get_output_size()); +snippets::op::Store::Store(const Output& x, const size_t count, const size_t offset) : MemoryAccess({x}, 0, 1) { set_output_port_descriptor({count, offset}, 0); constructor_validate_and_infer_types(); } @@ -31,14 +30,6 @@ std::shared_ptr snippets::op::Store::clone_with_new_inputs(const OutputVec return std::make_shared(new_args.at(0), get_count(), get_offset()); } -void Store::set_input_port_descriptor(const MemoryAccess::PortDescriptor& desc, const size_t i) { - throw ov::Exception("Store node doesn't have memory access input port"); -} - -const MemoryAccess::PortDescriptor& Store::get_input_port_descriptor(const size_t i) const { - throw ov::Exception("Store node doesn't have memory access input port"); -} - } // namespace op } // namespace snippets } // namespace ngraph diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp index 0d19f1a6999e9f..0e4004395e188a 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp @@ -14,14 +14,16 @@ using namespace ov; intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output& x, const element::Type src_type, const Type type, const size_t offset_in, const size_t offset_out0, const size_t offset_out1) - : ngraph::snippets::op::MemoryAccess({x}), m_type(type), m_src_type(src_type) { - set_output_size(is_with_compensations() ? 2 : 1); - constructor_validate_and_infer_types(); + : ngraph::snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1), m_type(type), m_src_type(src_type) { + set_output_size(get_output_port_count()); + m_input_ports.resize(get_input_size()); + m_output_ports.resize(get_output_size()); set_input_port_descriptor({0, offset_in}, 0); set_output_port_descriptor({0, offset_out0}, 0); if (is_with_compensations()) { set_output_port_descriptor({0, offset_out1}, 1); } + constructor_validate_and_infer_types(); } bool intel_cpu::BrgemmCopyB::visit_attributes(AttributeVisitor& visitor) { @@ -33,7 +35,6 @@ bool intel_cpu::BrgemmCopyB::visit_attributes(AttributeVisitor& visitor) { void intel_cpu::BrgemmCopyB::validate_and_infer_types() { INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types); - MemoryAccess::validate_and_infer_types(); const auto element_type = get_input_element_type(0); NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8), diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp index ceddc612d831b5..67e85394063c66 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp @@ -19,10 +19,12 @@ BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Type ty // We call default ctor of Brgemm class to avoid incorrect shape infer in constructor_validate_and_type_infer() call set_arguments({A, B}); set_output_size(1); - constructor_validate_and_infer_types(); + m_input_ports.resize(get_input_size()); + m_output_ports.resize(get_output_size()); set_input_port_descriptor({0, offset_a}, 0); set_input_port_descriptor({0, offset_b}, 1); set_output_port_descriptor({0, offset_c}, 0); + constructor_validate_and_infer_types(); } BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Output& scratch, const Type type, @@ -30,16 +32,17 @@ BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Output< : Brgemm(), m_type(type) { set_arguments({A, B, scratch}); set_output_size(1); - constructor_validate_and_infer_types(); + m_input_ports.resize(get_input_size()); + m_output_ports.resize(get_output_size()); set_input_port_descriptor({0, offset_a}, 0); set_input_port_descriptor({0, offset_b}, 1); set_output_port_descriptor({0, offset_c}, 0); set_input_port_descriptor({0, offset_scratch}, 2); + constructor_validate_and_infer_types(); } void BrgemmCPU::validate_and_infer_types() { INTERNAL_OP_SCOPE(BrgemmCPU_validate_and_infer_types); - MemoryAccess::validate_and_infer_types(); // If no leading dimensions are provided, assume dense row-major inputs-outputs NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(), "BrgemmCPU currently supports only static shapes.");