Skip to content

Commit

Permalink
Add custom mask for MaxPool for CPU data dependency only
Browse files Browse the repository at this point in the history
Signed-off-by: Raasz, Pawel <[email protected]>
  • Loading branch information
praasz committed Dec 9, 2024
1 parent 3829b7b commit e40addb
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ class ShapeInferTA<TOp, EMPTY_PORT_MASK> : public ShapeInferBase {
/**
* @brief Shape inference for v0 FakeQuantize.
*
* It requires dedicated port mask for data dependency but not use by shape inference function.
* It requires dedicated port mask for CPU output shape data dependency but is not used by inference function.
* Review shape_infer function to include this dependency.
*/
template <>
class ShapeInferTA<ov::op::v0::FakeQuantize, EMPTY_PORT_MASK> : public ShapeInferBase {
Expand Down Expand Up @@ -379,6 +380,48 @@ class ShapeInferPaddingTA<TOp, EMPTY_PORT_MASK> : public ShapeInferPaddingBase {
}
};

/**
* @brief Shape inference for v14 MaxPool.
*
* It requires dedicated port mask for CPU output shape data dependency but is not used by inference function.
* Review shape_infer function to include this dependency.
*/
template <>
class ShapeInferPaddingTA<ov::op::v14::MaxPool, EMPTY_PORT_MASK> : public ShapeInferPaddingBase {
public:
using ShapeInferPaddingBase::ShapeInferPaddingBase;

ov::optional<std::vector<StaticShape>> infer(const std::vector<StaticShapeRef>& input_shapes,
const ov::ITensorAccessor&) override {
return {shape_infer(static_cast<ov::op::v14::MaxPool*>(m_node.get()), input_shapes, m_pads_begin, m_pads_end)};
}

port_mask_t get_port_mask() const override {
return util::bit::mask(0);
}
};

/**
* @brief Shape inference for v8 MaxPool.
*
* It requires dedicated port mask for CPU output shape data dependency but is not used by inference function.
* Review shape_infer function to include this dependency.
*/
template <>
class ShapeInferPaddingTA<ov::op::v8::MaxPool, EMPTY_PORT_MASK> : public ShapeInferPaddingBase {
public:
using ShapeInferPaddingBase::ShapeInferPaddingBase;

ov::optional<std::vector<StaticShape>> infer(const std::vector<StaticShapeRef>& input_shapes,
const ov::ITensorAccessor&) override {
return {shape_infer(static_cast<ov::op::v8::MaxPool*>(m_node.get()), input_shapes, m_pads_begin, m_pads_end)};
}

port_mask_t get_port_mask() const override {
return util::bit::mask(0);
}
};

/**
* \brief Shape infer factory
*
Expand Down

0 comments on commit e40addb

Please sign in to comment.