-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CPU]: Added ROIAlignRotated basic impl. #23844
Changes from all commits
4a5775b
e0c13d1
e713a7b
fb113e6
b09db55
f5e31b9
77a4223
5943fa8
eb7720d
ce5dcc6
014f931
d3a3c57
7b56861
0756ecf
9008667
ac066f7
e0e78a3
190cc6b
af1b95e
ee705be
38206ee
5e369cc
074a209
995c542
2b4aec4
c2a1228
4b831c5
4ecbfa7
242d91f
fbe71f8
dc2600a
456b71b
bd0bb00
1483b8d
75186a3
8a19209
ec02c6c
7e78b9c
3f067ae
6f30a18
3f44778
7c07606
047237b
7d88f29
5e8aa9a
7fa0b88
cfd0104
0f1d0bc
848af5a
947862c
a7de6e2
44d34fa
fe8159b
aca6338
42a534b
f6dd302
5624e0a
328dabf
6ad51f3
33960ed
a9a6316
c68e148
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,6 +41,7 @@ enum class Type { | |
NonZero, | ||
Tile, | ||
ROIAlign, | ||
ROIAlignRotated, | ||
ROIPooling, | ||
PSROIPooling, | ||
BatchToSpace, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "roi_align_rotated.h" | ||
|
||
#include <openvino/opsets/opset14.hpp> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please include just required operation headers. |
||
|
||
#include "common/cpu_convert.h" | ||
#include "openvino/reference/roi_align.hpp" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace node { | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need to add fn |
||
ROIAlignRotated::ROIAlignRotated(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context) | ||
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { | ||
const auto roiAlign = ov::as_type_ptr<const ov::opset14::ROIAlignRotated>(op); | ||
pooledH = roiAlign->get_pooled_h(); | ||
pooledW = roiAlign->get_pooled_w(); | ||
spatialScale = roiAlign->get_spatial_scale(); | ||
samplingRatio = roiAlign->get_sampling_ratio(); | ||
clockwiseMode = roiAlign->get_clockwise_mode(); | ||
} | ||
|
||
void ROIAlignRotated::getSupportedDescriptors() { | ||
// Validation is already done in the ov::opset14::ROIAlignRotated. | ||
} | ||
|
||
void ROIAlignRotated::initSupportedPrimitiveDescriptors() { | ||
if (!supportedPrimitiveDescriptors.empty()) | ||
return; | ||
|
||
ov::element::Type inputPrec0 = getOriginalInputPrecisionAtPort(0); | ||
ov::element::Type outputPrec = getOriginalOutputPrecisionAtPort(0); | ||
|
||
addSupportedPrimDesc( | ||
{{LayoutType::ncsp, inputPrec0}, {LayoutType::ncsp, ov::element::f32}, {LayoutType::ncsp, ov::element::i32}}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As per specification, 1st and 2nd inputs have the same precisions. Why the second one is always f32? |
||
{{LayoutType::ncsp, outputPrec}}, | ||
impl_desc_type::ref); | ||
} | ||
|
||
bool ROIAlignRotated::created() const { | ||
return getType() == Type::ROIAlignRotated; | ||
} | ||
|
||
bool ROIAlignRotated::needPrepareParams() const { | ||
return false; | ||
} | ||
|
||
void ROIAlignRotated::executeDynamicImpl(dnnl::stream strm) { | ||
execute(strm); | ||
} | ||
|
||
template <ov::element::Type_t OV_TYPE> | ||
void ROIAlignRotated::executeImpl() { | ||
using T = typename ov::element_type_traits<OV_TYPE>::value_type; | ||
|
||
const size_t batch_indices_size = getSrcMemoryAtPort(2)->getShape().getElementsCount(); | ||
|
||
std::vector<int64_t> batch_indices_vec_scaled_up(batch_indices_size); | ||
cpu_convert(getSrcMemoryAtPort(2)->getData(), | ||
batch_indices_vec_scaled_up.data(), | ||
getSrcMemoryAtPort(2)->getPrecision(), | ||
ov::element::i64, | ||
batch_indices_size); | ||
|
||
ov::reference::roi_align<T, ov::reference::roi_policy::ROIAlignRotatedOpDefPolicy>( | ||
getSrcDataAtPortAs<const T>(0), | ||
getSrcDataAtPortAs<const T>(1), | ||
batch_indices_vec_scaled_up.data(), | ||
getDstDataAtPortAs<T>(0), | ||
ov::Shape{getSrcMemoryAtPort(0)->getStaticDims()}, | ||
ov::Shape{getSrcMemoryAtPort(1)->getStaticDims()}, | ||
ov::Shape{getSrcMemoryAtPort(2)->getStaticDims()}, | ||
ov::Shape{getDstMemoryAtPort(0)->getStaticDims()}, | ||
pooledH, | ||
pooledW, | ||
samplingRatio, | ||
spatialScale, | ||
ov::op::v3::ROIAlign::PoolingMode::AVG, | ||
ov::op::v9::ROIAlign::AlignedMode::ASYMMETRIC, | ||
clockwiseMode); | ||
} | ||
|
||
void ROIAlignRotated::execute(dnnl::stream) { | ||
const ov::element::Type type = getOriginalInputPrecisionAtPort(0); | ||
executeImpl<ov::element::f32>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this fn is called twice? |
||
|
||
#define CASE(OV_TYPE) \ | ||
case ov::element::OV_TYPE: \ | ||
executeImpl<ov::element::OV_TYPE>(); \ | ||
break; | ||
|
||
switch (type) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use OV_SWITCH instead |
||
CASE(bf16); | ||
CASE(f16); | ||
CASE(f32); | ||
CASE(f64); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not supported by plugin |
||
default: | ||
OPENVINO_THROW("[ROIAlignRotated]: Unhandled data type ", type, " in execute()"); | ||
} | ||
#undef CASE | ||
} | ||
|
||
} // namespace node | ||
} // namespace intel_cpu | ||
} // namespace ov |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "node.h" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace node { | ||
|
||
class ROIAlignRotated : public Node { | ||
public: | ||
ROIAlignRotated(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context); | ||
|
||
void getSupportedDescriptors() override; | ||
void initSupportedPrimitiveDescriptors() override; | ||
void execute(dnnl::stream strm) override; | ||
bool created() const override; | ||
bool needPrepareParams() const override; | ||
void executeDynamicImpl(dnnl::stream strm) override; | ||
|
||
private: | ||
template <ov::element::Type_t OV_TYPE> | ||
void executeImpl(); | ||
|
||
int pooledH; | ||
int pooledW; | ||
int samplingRatio; | ||
float spatialScale; | ||
bool clockwiseMode; | ||
}; | ||
|
||
} // namespace node | ||
} // namespace intel_cpu | ||
} // namespace ov |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a ticket number
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CVS-141656