-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CPU]: Added ROIAlignRotated basic impl. (#23844)
### Details: - Added basic implementation of ROIAlignRotated - It uses implementation from core::reference - the goal was just to add support for the op, not the optimized implementation. - No unit tests added, since impl is already tested by core::reference and functional tests. ### Tickets: - [CVS-135847](https://jira.devtools.intel.com/browse/CVS-135847) --------- Co-authored-by: Pawel Raasz <[email protected]>
- Loading branch information
Showing
10 changed files
with
174 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,6 +41,7 @@ enum class Type { | |
NonZero, | ||
Tile, | ||
ROIAlign, | ||
ROIAlignRotated, | ||
ROIPooling, | ||
PSROIPooling, | ||
BatchToSpace, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "roi_align_rotated.h" | ||
|
||
#include <openvino/opsets/opset14.hpp> | ||
|
||
#include "common/cpu_convert.h" | ||
#include "openvino/reference/roi_align.hpp" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace node { | ||
|
||
ROIAlignRotated::ROIAlignRotated(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context) | ||
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { | ||
const auto roiAlign = ov::as_type_ptr<const ov::opset14::ROIAlignRotated>(op); | ||
pooledH = roiAlign->get_pooled_h(); | ||
pooledW = roiAlign->get_pooled_w(); | ||
spatialScale = roiAlign->get_spatial_scale(); | ||
samplingRatio = roiAlign->get_sampling_ratio(); | ||
clockwiseMode = roiAlign->get_clockwise_mode(); | ||
} | ||
|
||
void ROIAlignRotated::getSupportedDescriptors() { | ||
// Validation is already done in the ov::opset14::ROIAlignRotated. | ||
} | ||
|
||
void ROIAlignRotated::initSupportedPrimitiveDescriptors() { | ||
if (!supportedPrimitiveDescriptors.empty()) | ||
return; | ||
|
||
ov::element::Type inputPrec0 = getOriginalInputPrecisionAtPort(0); | ||
ov::element::Type outputPrec = getOriginalOutputPrecisionAtPort(0); | ||
|
||
addSupportedPrimDesc( | ||
{{LayoutType::ncsp, inputPrec0}, {LayoutType::ncsp, ov::element::f32}, {LayoutType::ncsp, ov::element::i32}}, | ||
{{LayoutType::ncsp, outputPrec}}, | ||
impl_desc_type::ref); | ||
} | ||
|
||
bool ROIAlignRotated::created() const { | ||
return getType() == Type::ROIAlignRotated; | ||
} | ||
|
||
bool ROIAlignRotated::needPrepareParams() const { | ||
return false; | ||
} | ||
|
||
void ROIAlignRotated::executeDynamicImpl(dnnl::stream strm) { | ||
execute(strm); | ||
} | ||
|
||
template <ov::element::Type_t OV_TYPE> | ||
void ROIAlignRotated::executeImpl() { | ||
using T = typename ov::element_type_traits<OV_TYPE>::value_type; | ||
|
||
const size_t batch_indices_size = getSrcMemoryAtPort(2)->getShape().getElementsCount(); | ||
|
||
std::vector<int64_t> batch_indices_vec_scaled_up(batch_indices_size); | ||
cpu_convert(getSrcMemoryAtPort(2)->getData(), | ||
batch_indices_vec_scaled_up.data(), | ||
getSrcMemoryAtPort(2)->getPrecision(), | ||
ov::element::i64, | ||
batch_indices_size); | ||
|
||
ov::reference::roi_align<T, ov::reference::roi_policy::ROIAlignRotatedOpDefPolicy>( | ||
getSrcDataAtPortAs<const T>(0), | ||
getSrcDataAtPortAs<const T>(1), | ||
batch_indices_vec_scaled_up.data(), | ||
getDstDataAtPortAs<T>(0), | ||
ov::Shape{getSrcMemoryAtPort(0)->getStaticDims()}, | ||
ov::Shape{getSrcMemoryAtPort(1)->getStaticDims()}, | ||
ov::Shape{getSrcMemoryAtPort(2)->getStaticDims()}, | ||
ov::Shape{getDstMemoryAtPort(0)->getStaticDims()}, | ||
pooledH, | ||
pooledW, | ||
samplingRatio, | ||
spatialScale, | ||
ov::op::v3::ROIAlign::PoolingMode::AVG, | ||
ov::op::v9::ROIAlign::AlignedMode::ASYMMETRIC, | ||
clockwiseMode); | ||
} | ||
|
||
void ROIAlignRotated::execute(dnnl::stream) { | ||
const ov::element::Type type = getOriginalInputPrecisionAtPort(0); | ||
executeImpl<ov::element::f32>(); | ||
|
||
#define CASE(OV_TYPE) \ | ||
case ov::element::OV_TYPE: \ | ||
executeImpl<ov::element::OV_TYPE>(); \ | ||
break; | ||
|
||
switch (type) { | ||
CASE(bf16); | ||
CASE(f16); | ||
CASE(f32); | ||
CASE(f64); | ||
default: | ||
OPENVINO_THROW("[ROIAlignRotated]: Unhandled data type ", type, " in execute()"); | ||
} | ||
#undef CASE | ||
} | ||
|
||
} // namespace node | ||
} // namespace intel_cpu | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "node.h" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace node { | ||
|
||
class ROIAlignRotated : public Node { | ||
public: | ||
ROIAlignRotated(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context); | ||
|
||
void getSupportedDescriptors() override; | ||
void initSupportedPrimitiveDescriptors() override; | ||
void execute(dnnl::stream strm) override; | ||
bool created() const override; | ||
bool needPrepareParams() const override; | ||
void executeDynamicImpl(dnnl::stream strm) override; | ||
|
||
private: | ||
template <ov::element::Type_t OV_TYPE> | ||
void executeImpl(); | ||
|
||
int pooledH; | ||
int pooledW; | ||
int samplingRatio; | ||
float spatialScale; | ||
bool clockwiseMode; | ||
}; | ||
|
||
} // namespace node | ||
} // namespace intel_cpu | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters