Skip to content

Commit

Permalink
Add custom shape inference for deconv node
Browse files Browse the repository at this point in the history
Signed-off-by: Raasz, Pawel <[email protected]>
  • Loading branch information
praasz committed Dec 6, 2024
1 parent da7e5a8 commit 3829b7b
Showing 1 changed file with 55 additions and 1 deletion.
56 changes: 55 additions & 1 deletion src/plugins/intel_cpu/src/nodes/deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,60 @@ bool DeconvKey::operator==(const DeconvKey &rhs) const {
retVal = retVal && *attr.get() == *rhs.attr.get() && implType == rhs.implType;
return retVal;
}

// class FCShapeInfer : public ShapeInferEmptyPads {
// public:
// FCShapeInfer(size_t outPut_rank) : out_rank(outPut_rank) {}
// Result infer(const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
// const std::unordered_map<size_t, MemoryPtr>& data_dependency) override;

// port_mask_t get_port_mask() const override {
// return EMPTY_PORT_MASK;
// }

// private:
// size_t out_rank = 0;
// };

class DeconvolutionShapeInferFactory : public ShapeInferFactory {
public:
DeconvolutionShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(std::move(op)) {}

ShapeInferPtr makeShapeInfer() const override {
return std::make_shared<DeconvolutionShapeInfer>(m_op);
}

private:
class DeconvolutionShapeInfer : public IShapeInfer {
public:
DeconvolutionShapeInfer(const std::shared_ptr<ov::Node>& op)
: m_shape_infer(make_shape_inference(op)),
m_port_mask((op->get_input_size() > 2) ? PortMask(2) : EMPTY_PORT_MASK) {}

Result infer(const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
return m_shape_infer->infer(input_shapes, data_dependency);
}

const ov::CoordinateDiff& get_pads_begin() override {
return m_shape_infer->get_pads_begin();
}

const ov::CoordinateDiff& get_pads_end() override {
return m_shape_infer->get_pads_end();
}

port_mask_t get_port_mask() const override {
return m_port_mask;
};

private:
ShapeInferPtr m_shape_infer;
const port_mask_t m_port_mask;
};

std::shared_ptr<ov::Node> m_op;
};
} // namespace

bool Deconvolution::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
Expand All @@ -145,7 +199,7 @@ bool Deconvolution::isSupportedOperation(const std::shared_ptr<const ov::Node>&
}

Deconvolution::Deconvolution(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
: Node(op, context, NgraphShapeInferFactory(op)) {
: Node(op, context, DeconvolutionShapeInferFactory(op)) {
std::string errorMessage;
errorPrefix = "Deconvolution node with name '" + getName() + "' ";
if (!isSupportedOperation(op, errorMessage))
Expand Down

0 comments on commit 3829b7b

Please sign in to comment.