From e40addb1c83d582ee6ca0096ecb27c42f192420b Mon Sep 17 00:00:00 2001 From: "Raasz, Pawel" Date: Mon, 9 Dec 2024 10:20:23 +0000 Subject: [PATCH] Add custom mask for MaxPool for CPU data dependency only Signed-off-by: Raasz, Pawel --- .../src/shape_inference/shape_inference.cpp | 45 ++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp b/src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp index 9f961fde2bdb7d..d34e4459cf76c7 100644 --- a/src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp @@ -307,7 +307,8 @@ class ShapeInferTA : public ShapeInferBase { /** * @brief Shape inference for v0 FakeQuantize. * - * It requires dedicated port mask for data dependency but not use by shape inference function. + * It requires dedicated port mask for CPU output shape data dependency but is not used by inference function. + * Review shape_infer function to include this dependency. */ template <> class ShapeInferTA : public ShapeInferBase { @@ -379,6 +380,48 @@ class ShapeInferPaddingTA : public ShapeInferPaddingBase { } }; +/** + * @brief Shape inference for v14 MaxPool. + * + * It requires dedicated port mask for CPU output shape data dependency but is not used by inference function. + * Review shape_infer function to include this dependency. + */ +template <> +class ShapeInferPaddingTA : public ShapeInferPaddingBase { +public: + using ShapeInferPaddingBase::ShapeInferPaddingBase; + + ov::optional> infer(const std::vector& input_shapes, + const ov::ITensorAccessor&) override { + return {shape_infer(static_cast(m_node.get()), input_shapes, m_pads_begin, m_pads_end)}; + } + + port_mask_t get_port_mask() const override { + return util::bit::mask(0); + } +}; + +/** + * @brief Shape inference for v8 MaxPool. + * + * It requires dedicated port mask for CPU output shape data dependency but is not used by inference function. + * Review shape_infer function to include this dependency. + */ +template <> +class ShapeInferPaddingTA : public ShapeInferPaddingBase { +public: + using ShapeInferPaddingBase::ShapeInferPaddingBase; + + ov::optional> infer(const std::vector& input_shapes, + const ov::ITensorAccessor&) override { + return {shape_infer(static_cast(m_node.get()), input_shapes, m_pads_begin, m_pads_end)}; + } + + port_mask_t get_port_mask() const override { + return util::bit::mask(0); + } +}; + /** * \brief Shape infer factory *