Skip to content

Commit

Permalink
[Op] Extend CPU Plugin with Col2Im reference implementation (openvi…
Browse files Browse the repository at this point in the history
…notoolkit#24931)

### Details:
 - Register `Col2im-15` in CPU Plugin
- Use reference implementation, used
openvinotoolkit#23844 as reference
 - Add tests

### Tickets:
 - CVS-142438

### Related PRs:
 - openvinotoolkit#24548
 - openvinotoolkit#24197
 - openvinotoolkit#23947
 - openvinotoolkit#24569

---------

Co-authored-by: Michal Lukaszewski <[email protected]>
Co-authored-by: Maksim Kutakov <[email protected]>
  • Loading branch information
3 people authored Jul 3, 2024
1 parent 1e12297 commit ba05386
Show file tree
Hide file tree
Showing 9 changed files with 531 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/cpu_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
{"NV12toBGR", Type::ColorConvert},
{"I420toRGB", Type::ColorConvert},
{"I420toBGR", Type::ColorConvert},
{"Col2Im", Type::Col2Im},
{"MVN", Type::MVN},
{"NormalizeL2", Type::NormalizeL2},
{"ScatterUpdate", Type::ScatterUpdate},
Expand Down Expand Up @@ -305,6 +306,7 @@ std::string NameFromType(const Type type) {
CASE(MVN);
CASE(TensorIterator);
CASE(Convert);
CASE(Col2Im);
CASE(ColorConvert);
CASE(NormalizeL2);
CASE(ScatterUpdate);
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/cpu_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ enum class Type {
TensorIterator,
Convert,
ColorConvert,
Col2Im,
MVN,
NormalizeL2,
ScatterUpdate,
Expand Down
110 changes: 110 additions & 0 deletions src/plugins/intel_cpu/src/nodes/col2im.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "col2im.h"
#include "openvino/reference/col2im.hpp"
#include "openvino/op/col2im.hpp"

namespace ov {
namespace intel_cpu {
namespace node {
Col2Im::Col2Im(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) {
OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage);
}
const auto col2Im = ov::as_type_ptr<const ov::op::v15::Col2Im>(op);
strides = col2Im->get_strides();
dilations = col2Im->get_dilations();
padsBegin = col2Im->get_pads_begin();
padsEnd = col2Im->get_pads_end();
}

bool Col2Im::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
try {
if (!ov::is_type<ov::op::v15::Col2Im>(op)) {
errorMessage = "Only opset15 Col2Im operation is supported";
return false;
}
} catch (...) {
return false;
}
return true;
}

void Col2Im::getSupportedDescriptors() {
// Validation is already done in the ov::opset15::Col2Im.
}

void Col2Im::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
ov::element::Type dataPrecision = getOriginalInputPrecisionAtPort(0);
addSupportedPrimDesc(
{{LayoutType::ncsp, dataPrecision}, {LayoutType::ncsp, ov::element::i32}, {LayoutType::ncsp, ov::element::i32}},
{{LayoutType::ncsp, dataPrecision}},
impl_desc_type::ref);
}

bool Col2Im::created() const {
return getType() == Type::Col2Im;
}

bool Col2Im::needPrepareParams() const {
return false;
}

void Col2Im::executeDynamicImpl(dnnl::stream strm) {
execute(strm);
}

template <class T, class T_idx>
void Col2Im::executeImpl() {
ov::reference::col2im<T, T_idx>(
getSrcDataAtPortAs<const T>(0),
ov::Shape{getSrcMemoryAtPort(0)->getStaticDims()},
getSrcDataAtPortAs<const T_idx>(1),
getSrcDataAtPortAs<const T_idx>(2),
getDstDataAtPortAs<T>(0),
strides,
dilations,
padsBegin,
padsEnd);
}

namespace {
struct Col2ImContext {
Col2Im &node;
};
}

template<typename T>
struct Col2Im::Col2ImExecute {
using TData = typename std::tuple_element<0, T>::type;
using TIndex = typename std::tuple_element<1, T>::type;

void operator()(Col2ImContext & ctx) {
ctx.node.executeImpl<TData, TIndex>();
}
};
void Col2Im::execute(dnnl::stream strm) {
auto dataPrecision = getParentEdgeAt(0)->getMemory().getDesc().getPrecision();
auto indexPrecision = getParentEdgeAt(1)->getMemory().getDesc().getPrecision();

Col2ImContext ctx = {
*this
};

OV_SWITCH(intel_cpu, Col2ImExecute, ctx, std::tie(dataPrecision, indexPrecision),
OV_CASE2(ov::element::f32, ov::element::i32, float, int32_t),
OV_CASE2(ov::element::f16, ov::element::i32, ov::float16, int32_t),
OV_CASE2(ov::element::bf16, ov::element::i32, ov::bfloat16, int32_t),
OV_CASE2(ov::element::i32, ov::element::i32, int32_t, int32_t),
OV_CASE2(ov::element::i8, ov::element::i32, int8_t, int32_t),
OV_CASE2(ov::element::u8, ov::element::i32, uint8_t, int32_t))
}
} // namespace node
} // namespace intel_cpu
} // namespace ov
40 changes: 40 additions & 0 deletions src/plugins/intel_cpu/src/nodes/col2im.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "node.h"

namespace ov {
namespace intel_cpu {
namespace node {

class Col2Im : public Node {
public:
Col2Im(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context);

static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
void getSupportedDescriptors() override;
void initSupportedPrimitiveDescriptors() override;
void execute(dnnl::stream strm) override;
bool created() const override;
bool needPrepareParams() const override;
void executeDynamicImpl(dnnl::stream strm) override;

private:
template <class OV_DATA_TYPE, class OV_INDEX_TYPE>
void executeImpl();

template<typename T>
struct Col2ImExecute;

ov::Strides strides;
ov::Strides dilations;
ov::Shape padsBegin;
ov::Shape padsEnd;
};

} // namespace node
} // namespace intel_cpu
} // namespace ov
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/nodes_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "nodes/bin_conv.h"
#include "nodes/broadcast.h"
#include "nodes/bucketize.h"
#include "nodes/col2im.h"
#include "nodes/color_convert.h"
#include "nodes/concat.h"
#include "nodes/conv.h"
Expand Down Expand Up @@ -160,6 +161,7 @@ Node::NodesFactory::NodesFactory() : Factory("NodesFactory") {
INTEL_CPU_NODE(Math, Type::Math);
INTEL_CPU_NODE(MultiClassNms, Type::MulticlassNms);
INTEL_CPU_NODE(Convert, Type::Convert);
INTEL_CPU_NODE(Col2Im, Type::Col2Im);
INTEL_CPU_NODE(ColorConvert, Type::ColorConvert);
INTEL_CPU_NODE(EmbeddingBagOffset, Type::EmbeddingBagOffsetsSum);
INTEL_CPU_NODE(EmbeddingBagOffset, Type::EmbeddingBagOffsets);
Expand Down
Loading

0 comments on commit ba05386

Please sign in to comment.