forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[STFT][CPU] Add STFT op to CPU and enable by ref (openvinotoolkit#27137)
### Details: - Enablement of STFT op in CPU by ref impl - Including support for dynamic shapes, and shape related inputs as Parameters ** There is ongoing work to reuse RDFT Executor within STFT CPU impl. Due to number of needed changes, decided to enable by ref first. ### Tickets: - 147161 Related PR: - openvinotoolkit#27186
- Loading branch information
Showing
14 changed files
with
544 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,6 +89,7 @@ enum class Type { | |
ShuffleChannels, | ||
DFT, | ||
RDFT, | ||
STFT, | ||
Math, | ||
CTCLoss, | ||
Bucketize, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "stft.h" | ||
|
||
#include "openvino/core/type.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/stft.hpp" | ||
#include "openvino/reference/stft.hpp" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace node { | ||
|
||
bool STFT::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept { | ||
try { | ||
if (op->get_type_info() != op::v15::STFT::get_type_info_static()) { | ||
errorMessage = "Only STFT operation from the opset15 is supported by the CPU plugin."; | ||
return false; | ||
} | ||
} catch (...) { | ||
return false; | ||
} | ||
return true; | ||
} | ||
|
||
STFT::STFT(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context) | ||
: Node(op, context, NgraphShapeInferFactory(op, PortMask(2, 3))) { | ||
std::string errorMessage; | ||
if (!isSupportedOperation(op, errorMessage)) { | ||
THROW_CPU_NODE_ERR(errorMessage); | ||
} | ||
|
||
const auto stft_op = as_type_ptr<op::v15::STFT>(op); | ||
m_transpose_frames = stft_op->get_transpose_frames(); | ||
|
||
m_is_frame_size_const = is_type<op::v0::Constant>(stft_op->get_input_node_ptr(FRAME_SIZE_IDX)); | ||
m_is_frame_step_const = is_type<op::v0::Constant>(stft_op->get_input_node_ptr(FRAME_STEP_IDX)); | ||
} | ||
|
||
void STFT::getSupportedDescriptors() { | ||
if (getParentEdges().size() != 4) { | ||
THROW_CPU_NODE_ERR("STFT has incorrect number of input edges."); | ||
} | ||
if (getChildEdges().empty()) { | ||
THROW_CPU_NODE_ERR("STFT has incorrect number of output edges."); | ||
} | ||
} | ||
|
||
void STFT::initSupportedPrimitiveDescriptors() { | ||
if (!supportedPrimitiveDescriptors.empty()) | ||
return; | ||
|
||
auto dataPrecision = getOriginalInputPrecisionAtPort(DATA_IDX); | ||
if (!one_of(dataPrecision, ov::element::f32)) { | ||
dataPrecision = ov::element::f32; | ||
} | ||
|
||
std::vector<PortConfigurator> configurators({{LayoutType::ncsp, dataPrecision}, | ||
{LayoutType::ncsp, dataPrecision}, | ||
{LayoutType::ncsp, ov::element::i32}, | ||
{LayoutType::ncsp, ov::element::i32}}); | ||
|
||
addSupportedPrimDesc(configurators, {{LayoutType::ncsp, dataPrecision}}, impl_desc_type::ref_any); | ||
} | ||
|
||
bool STFT::needPrepareParams() const { | ||
return false; | ||
} | ||
|
||
bool STFT::created() const { | ||
return getType() == Type::STFT; | ||
} | ||
|
||
void STFT::execute(dnnl::stream strm) { | ||
ov::reference::stft(getSrcDataAtPortAs<const float>(DATA_IDX), | ||
getSrcDataAtPortAs<const float>(WINDOW_IDX), | ||
getDstDataAtPortAs<float>(0), | ||
ov::Shape{getSrcMemoryAtPort(DATA_IDX)->getStaticDims()}, | ||
ov::Shape{getSrcMemoryAtPort(WINDOW_IDX)->getStaticDims()}, | ||
(getSrcDataAtPortAs<const int32_t>(FRAME_SIZE_IDX))[0], | ||
(getSrcDataAtPortAs<const int32_t>(FRAME_STEP_IDX))[0], | ||
m_transpose_frames); | ||
} | ||
|
||
void STFT::executeDynamicImpl(dnnl::stream strm) { | ||
execute(strm); | ||
} | ||
|
||
bool STFT::needShapeInfer() const { | ||
return !(m_is_frame_size_const && m_is_frame_step_const) || Node::needShapeInfer(); | ||
} | ||
|
||
} // namespace node | ||
} // namespace intel_cpu | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
|
||
#include "node.h" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace node { | ||
|
||
class STFT : public Node { | ||
public: | ||
STFT(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context); | ||
|
||
void getSupportedDescriptors() override; | ||
void initSupportedPrimitiveDescriptors() override; | ||
bool created() const override; | ||
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept; | ||
bool needPrepareParams() const override; | ||
|
||
void execute(dnnl::stream strm) override; | ||
void executeDynamicImpl(dnnl::stream strm) override; | ||
bool canBeInPlace() const override { | ||
return false; | ||
} | ||
|
||
protected: | ||
bool needShapeInfer() const override; | ||
|
||
private: | ||
/// STFT params | ||
bool m_transpose_frames = false; | ||
|
||
bool m_is_frame_size_const = false; | ||
bool m_is_frame_step_const = false; | ||
|
||
// Input indices | ||
static constexpr size_t DATA_IDX = 0lu; | ||
static constexpr size_t WINDOW_IDX = 1lu; | ||
static constexpr size_t FRAME_SIZE_IDX = 2lu; | ||
static constexpr size_t FRAME_STEP_IDX = 3lu; | ||
}; | ||
|
||
} // namespace node | ||
} // namespace intel_cpu | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
src/plugins/intel_cpu/tests/functional/shared_tests_instances/single_layer_tests/stft.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "single_op_tests/stft.hpp" | ||
|
||
#include <vector> | ||
|
||
#include "common_test_utils/test_constants.hpp" | ||
|
||
namespace ov { | ||
namespace test { | ||
using ov::test::STFTLayerTest; | ||
|
||
const std::vector<ov::element::Type> data_type = {ov::element::f32, ov::element::bf16}; | ||
const std::vector<ov::element::Type> step_size_type = {ov::element::i32, ov::element::i64}; | ||
|
||
const std::vector<std::vector<InputShape>> input_shapes = { | ||
{ // Static shapes | ||
{{}, {{1, 128}}}, // 1st input | ||
{{}, {{8}}}, // 2nd input | ||
{{}, {{}}}, // 3rd input | ||
{{}, {{}}} // 4th input | ||
}, | ||
{ // Static shapes | ||
{{}, {{2, 226}}}, // 1st input | ||
{{}, {{16}}}, // 2nd input | ||
{{}, {{}}}, // 3rd input | ||
{{}, {{}}} // 4th input | ||
}, | ||
{ // Dynamic dims in the first input shape | ||
{{-1, -1}, {{1, 128}, {2, 226}}}, // 1st input | ||
{{}, {{8}}}, // 2nd input | ||
{{}, {{}}}, // 3rd input | ||
{{}, {{}}} // 4th input | ||
}, | ||
{ // Dynamic dims in the first and second input shape | ||
{{-1, -1}, {{1, 128}, {2, 226}}}, // 1st input | ||
{{-1}, {{8}, {16}}}, // 2nd input | ||
{{}, {{}}}, // 3rd input | ||
{{}, {{}}} // 4th input | ||
}, | ||
{ // Dynamic dims with range in the first and second input shape | ||
{{{2, 4}, {1, 300}}, {{2, 226}, {3, 128}}}, // 1st input | ||
{{{3, 16}}, {{4}, {16}}}, // 2nd input | ||
{{}, {{}}}, // 3rd input | ||
{{}, {{}}} // 4th input | ||
} | ||
}; | ||
|
||
const std::vector<int64_t> frame_size = {16, 24}; | ||
const std::vector<int64_t> step_size = {2, 3, 4}; | ||
|
||
const std::vector<bool> transpose_frames = { | ||
false, | ||
true, | ||
}; | ||
|
||
std::vector<utils::InputLayerType> in_types = {utils::InputLayerType::CONSTANT, utils::InputLayerType::PARAMETER}; | ||
|
||
const auto testCaseStatic = ::testing::Combine(::testing::ValuesIn(input_shapes), | ||
::testing::ValuesIn(frame_size), | ||
::testing::ValuesIn(step_size), | ||
::testing::ValuesIn(transpose_frames), | ||
::testing::ValuesIn(data_type), | ||
::testing::ValuesIn(step_size_type), | ||
::testing::ValuesIn(in_types), | ||
::testing::Values(ov::test::utils::DEVICE_CPU)); | ||
|
||
INSTANTIATE_TEST_SUITE_P(smoke_STFT_static, STFTLayerTest, testCaseStatic, STFTLayerTest::getTestCaseName); | ||
} // namespace test | ||
} // namespace ov |
Oops, something went wrong.