Skip to content

Commit

Permalink
[STFT][CPU] Add STFT op to CPU and enable by ref (openvinotoolkit#27137)
Browse files Browse the repository at this point in the history
### Details:
 - Enable the STFT op in the CPU plugin through the reference implementation.
 - Includes support for dynamic shapes and for shape-related inputs provided as Parameters.

** There is ongoing work to reuse the RDFT Executor within the STFT CPU implementation.
Because of the number of changes that requires, the op is enabled through the reference implementation first.

### Tickets:
 - 147161
 
 Related PR:
 - openvinotoolkit#27186
  • Loading branch information
mitruska authored Oct 25, 2024
1 parent 2fae7bc commit 2d3289e
Show file tree
Hide file tree
Showing 14 changed files with 544 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/core/shape_inference/include/stft_shape_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ std::vector<TRShape> shape_infer(const STFT* op,
const std::vector<TShape>& input_shapes,
const ITensorAccessor& ta = make_tensor_accessor()) {
using TDim = typename TRShape::value_type;
using TDimVal = typename TDim::value_type;

NODE_VALIDATION_CHECK(op, input_shapes.size() == 4);

const auto& signal_shape = input_shapes[0];
Expand Down Expand Up @@ -46,15 +48,19 @@ std::vector<TRShape> shape_infer(const STFT* op,
if (signal_shape.rank().is_dynamic()) {
return {signal_shape};
} else if (!frame_size || !frame_step) {
return {TRShape{signal_shape[0], -1, -1, 2}};
return {TRShape{signal_shape[0], TDim(ov::util::dim::inf_bound), TDim(ov::util::dim::inf_bound), 2}};
}

const auto& frame_size_val = (*frame_size)[0];
const auto& frame_step_val = (*frame_step)[0];

const bool is_frame_size_in_range =
0 < frame_size_val &&
(signal_shape[1].is_static() ? static_cast<TDimVal>(frame_size_val) <= signal_shape[1].get_length()
: frame_size_val <= signal_shape[1].get_interval().get_max_val());
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
0 < frame_size_val && frame_size_val < signal_shape[1].get_interval().get_max_val(),
is_frame_size_in_range,
"Provided frame size is ",
frame_size_val,
" but must be in range [1, ",
Expand All @@ -68,16 +74,18 @@ std::vector<TRShape> shape_infer(const STFT* op,
frame_step_val,
" but must be greater than zero.");

const bool is_win_shape_correct =
window_shape.is_dynamic() || (TDimVal{0} < window_shape[0].get_length() &&
window_shape[0].get_length() <= static_cast<TDimVal>(frame_size_val));
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
window_shape.is_dynamic() ||
(0 < window_shape[0].get_length() && window_shape[0].get_length() <= frame_size_val),
is_win_shape_correct,
"Window input dimension must be in range [1, ",
frame_size_val,
"].");

const auto& batch_dim = signal_shape[0];
const TDim frame_size_dim = TDim{frame_size_val};
const TDim frame_size_dim = static_cast<TDim>(frame_size_val);
const TDim signal_frame_size_diff = signal_shape[1] - frame_size_dim;
TDim fft_samples_dim = (frame_size_val / 2) + 1;

Expand Down
1 change: 1 addition & 0 deletions src/core/tests/type_prop/stft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ INSTANTIATE_TEST_SUITE_P(
type_prop_stft_shape,
TypePropSTFTTestP,
testing::Values(
std::make_tuple(PartialShape{1, 16}, PartialShape{16}, 16, 16, true, PartialShape{1, 9, 1, 2}),
std::make_tuple(PartialShape{1, 48}, PartialShape{16}, 16, 16, true, PartialShape{1, 9, 3, 2}),
std::make_tuple(PartialShape{1, 48}, PartialShape{16}, 16, 16, false, PartialShape{1, 3, 9, 2}),
std::make_tuple(PartialShape{2, 48}, PartialShape{8}, 16, 4, true, PartialShape{2, 9, 9, 2}),
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/cpu_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
{"IDFT", Type::DFT},
{"RDFT", Type::RDFT},
{"IRDFT", Type::RDFT},
{"STFT", Type::STFT},
{"Abs", Type::Math},
{"Acos", Type::Math},
{"Acosh", Type::Math},
Expand Down Expand Up @@ -342,6 +343,7 @@ std::string NameFromType(const Type type) {
CASE(ShuffleChannels);
CASE(DFT);
CASE(RDFT);
CASE(STFT);
CASE(Math);
CASE(CTCLoss);
CASE(Bucketize);
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/cpu_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ enum class Type {
ShuffleChannels,
DFT,
RDFT,
STFT,
Math,
CTCLoss,
Bucketize,
Expand Down
97 changes: 97 additions & 0 deletions src/plugins/intel_cpu/src/nodes/stft.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "stft.h"

#include "openvino/core/type.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/stft.hpp"
#include "openvino/reference/stft.hpp"

namespace ov {
namespace intel_cpu {
namespace node {

/// Tells whether @p op can be executed by this CPU node.
///
/// Only the opset15 (v15) STFT operation is accepted. When the operation type
/// does not match, a human-readable reason is written to @p errorMessage.
/// The function is noexcept: any exception raised while inspecting the node is
/// swallowed and reported as "not supported" (in that case @p errorMessage is
/// left as it was when the exception occurred).
///
/// @param op            Operation to inspect.
/// @param errorMessage  Receives the rejection reason on a type mismatch.
/// @return true when the operation is supported, false otherwise.
bool STFT::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
    bool supported = false;
    try {
        if (op->get_type_info() != op::v15::STFT::get_type_info_static()) {
            errorMessage = "Only STFT operation from the opset15 is supported by the CPU plugin.";
        } else {
            supported = true;
        }
    } catch (...) {
        // Nothing to do: `supported` is only set as the very last step of the
        // try block, so it is still false here.
    }
    return supported;
}

/// Builds the CPU node from an STFT operation.
///
/// The shape-inference factory is created with PortMask(2, 3), i.e. the
/// frame-size and frame-step ports (see FRAME_SIZE_IDX / FRAME_STEP_IDX).
/// NOTE(review): presumably this marks the inputs whose *values* are needed by
/// shape inference — confirm against NgraphShapeInferFactory.
///
/// @param op       Source operation; must be an opset15 STFT.
/// @param context  Graph context forwarded to the base Node.
/// @throws Reports via THROW_CPU_NODE_ERR when the operation is not supported.
STFT::STFT(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context)
    : Node(op, context, NgraphShapeInferFactory(op, PortMask(2, 3))) {
    std::string errorMessage;
    const bool supported = isSupportedOperation(op, errorMessage);
    if (!supported) {
        THROW_CPU_NODE_ERR(errorMessage);
    }

    const auto stft_node = as_type_ptr<op::v15::STFT>(op);
    m_transpose_frames = stft_node->get_transpose_frames();

    // Remember whether frame size / frame step come from Constants; this is
    // what needShapeInfer() consults later.
    const auto* frame_size_producer = stft_node->get_input_node_ptr(FRAME_SIZE_IDX);
    const auto* frame_step_producer = stft_node->get_input_node_ptr(FRAME_STEP_IDX);
    m_is_frame_size_const = is_type<op::v0::Constant>(frame_size_producer);
    m_is_frame_step_const = is_type<op::v0::Constant>(frame_step_producer);
}

/// Validates the node's connectivity in the graph.
///
/// STFT takes exactly four inputs (signal, window, frame size, frame step) and
/// must have at least one consumer.
///
/// @throws Reports via THROW_CPU_NODE_ERR when the edge counts are wrong.
void STFT::getSupportedDescriptors() {
    constexpr size_t expected_inputs = 4;

    const auto input_edges = getParentEdges().size();
    if (input_edges != expected_inputs) {
        THROW_CPU_NODE_ERR("STFT has incorrect number of input edges.");
    }

    const bool has_consumers = !getChildEdges().empty();
    if (!has_consumers) {
        THROW_CPU_NODE_ERR("STFT has incorrect number of output edges.");
    }
}

void STFT::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;

auto dataPrecision = getOriginalInputPrecisionAtPort(DATA_IDX);
if (!one_of(dataPrecision, ov::element::f32)) {
dataPrecision = ov::element::f32;
}

std::vector<PortConfigurator> configurators({{LayoutType::ncsp, dataPrecision},
{LayoutType::ncsp, dataPrecision},
{LayoutType::ncsp, ov::element::i32},
{LayoutType::ncsp, ov::element::i32}});

addSupportedPrimDesc(configurators, {{LayoutType::ncsp, dataPrecision}}, impl_desc_type::ref_any);
}

/// This node never requests a parameter-preparation step: execute() reads the
/// shapes, frame size and frame step directly from the input memory on every
/// call.
///
/// @return Always false.
bool STFT::needPrepareParams() const {
    return false;
}

/// Sanity check used after construction.
///
/// @return true when the node's type is Type::STFT.
bool STFT::created() const {
    return Type::STFT == getType();
}

/// Runs STFT through the reference implementation (ov::reference::stft).
///
/// Signal and window are read as f32, the result is written as f32 to output
/// port 0. Frame size and frame step are read as the first i32 element of
/// their inputs on every call, so they may change between inferences.
///
/// @param strm  Execution stream; not used by the reference implementation.
void STFT::execute(dnnl::stream strm) {
    const auto* signal = getSrcDataAtPortAs<const float>(DATA_IDX);
    const auto* window = getSrcDataAtPortAs<const float>(WINDOW_IDX);
    auto* result = getDstDataAtPortAs<float>(0);

    const ov::Shape signal_shape{getSrcMemoryAtPort(DATA_IDX)->getStaticDims()};
    const ov::Shape window_shape{getSrcMemoryAtPort(WINDOW_IDX)->getStaticDims()};

    const auto frame_size = (getSrcDataAtPortAs<const int32_t>(FRAME_SIZE_IDX))[0];
    const auto frame_step = (getSrcDataAtPortAs<const int32_t>(FRAME_STEP_IDX))[0];

    ov::reference::stft(signal,
                        window,
                        result,
                        signal_shape,
                        window_shape,
                        frame_size,
                        frame_step,
                        m_transpose_frames);
}

/// Dynamic-shape entry point; simply forwards to execute(), which already
/// takes the shapes from the input memory at call time.
///
/// @param strm  Execution stream, forwarded unchanged.
void STFT::executeDynamicImpl(dnnl::stream strm) {
    this->execute(strm);
}

/// Decides whether shape inference has to be re-run.
///
/// If frame size or frame step is not produced by a Constant, shape inference
/// is always requested; otherwise the decision is left to the base Node.
///
/// @return true when shape inference is required.
bool STFT::needShapeInfer() const {
    if (!m_is_frame_size_const || !m_is_frame_step_const) {
        return true;
    }
    return Node::needShapeInfer();
}

} // namespace node
} // namespace intel_cpu
} // namespace ov
50 changes: 50 additions & 0 deletions src/plugins/intel_cpu/src/nodes/stft.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <string>

#include "node.h"

namespace ov {
namespace intel_cpu {
namespace node {

/// CPU plugin node for the STFT operation.
///
/// Inputs are addressed through the *_IDX constants below: signal (data),
/// window, frame size and frame step.
class STFT : public Node {
public:
    /// Creates the node from the given operation and graph context.
    STFT(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context);

    /// Returns whether @p op can be handled by this node; may fill
    /// @p errorMessage with the reason when it cannot.
    static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;

    // Node interface: descriptor setup and state queries.
    void getSupportedDescriptors() override;
    void initSupportedPrimitiveDescriptors() override;
    bool created() const override;
    bool needPrepareParams() const override;

    // Node interface: execution.
    void execute(dnnl::stream strm) override;
    void executeDynamicImpl(dnnl::stream strm) override;

    /// In-place execution is never allowed for this node.
    bool canBeInPlace() const override {
        return false;
    }

protected:
    bool needShapeInfer() const override;

private:
    /// STFT attribute: value of the operation's transpose_frames flag.
    bool m_transpose_frames = false;

    /// Whether the frame-size / frame-step inputs are produced by Constants.
    bool m_is_frame_size_const = false;
    bool m_is_frame_step_const = false;

    // Input port indices.
    static constexpr size_t DATA_IDX = 0;
    static constexpr size_t WINDOW_IDX = 1;
    static constexpr size_t FRAME_SIZE_IDX = 2;
    static constexpr size_t FRAME_STEP_IDX = 3;
};

} // namespace node
} // namespace intel_cpu
} // namespace ov
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/nodes_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
#include "nodes/space_to_batch.h"
#include "nodes/space_to_depth.h"
#include "nodes/split.h"
#include "nodes/stft.h"
#include "nodes/strided_slice.h"
#include "nodes/subgraph.h"
#include "nodes/tensoriterator.h"
Expand Down Expand Up @@ -214,6 +215,7 @@ Node::NodesFactory::NodesFactory() : Factory("NodesFactory") {
INTEL_CPU_NODE(RegionYolo, Type::RegionYolo);
INTEL_CPU_NODE(DFT, Type::DFT);
INTEL_CPU_NODE(RDFT, Type::RDFT);
INTEL_CPU_NODE(STFT, Type::STFT);
INTEL_CPU_NODE(ExtractImagePatches, Type::ExtractImagePatches);
INTEL_CPU_NODE(Subgraph, Type::Subgraph);
INTEL_CPU_NODE(Composite, Type::SubModel);
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
#include "split_shape_inference.hpp"
#include "squeeze_shape_inference.hpp"
#include "static_shape.hpp"
#include "stft_shape_inference.hpp"
#include "strided_slice_shape_inference.hpp"
#include "string_tensor_pack_shape_inference.hpp"
#include "string_tensor_unpack_shape_inference.hpp"
Expand Down Expand Up @@ -414,6 +415,7 @@ const IStaticShapeInferFactory::TRegistry IStaticShapeInferFactory::registry{
_OV_OP_SHAPE_INFER_MASK_REG(op::v15::Col2Im, ShapeInferTA, util::bit::mask(1, 2)),
_OV_OP_SHAPE_INFER_MASK_REG(op::v15::ScatterNDUpdate, ShapeInferTA, util::bit::mask()),
_OV_OP_SHAPE_INFER_MASK_REG(opset15::SliceScatter, ShapeInferTA, util::bit::mask(2, 3, 4, 5)),
_OV_OP_SHAPE_INFER_MASK_REG(op::v15::STFT, ShapeInferTA, util::bit::mask(2, 3)),
// opset14
_OV_OP_SHAPE_INFER_MASK_REG(opset14::Inverse, ShapeInferTA, util::bit::mask()),
_OV_OP_SHAPE_INFER_MASK_REG(opset14::MaxPool, ShapeInferPaddingTA, util::bit::mask()),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "single_op_tests/stft.hpp"

#include <vector>

#include "common_test_utils/test_constants.hpp"

namespace ov {
namespace test {
using ov::test::STFTLayerTest;

// Element types used for the floating-point inputs of the test model.
const std::vector<ov::element::Type> data_type{ov::element::f32, ov::element::bf16};
// Element types used for the integer inputs of the test model.
const std::vector<ov::element::Type> step_size_type{ov::element::i32, ov::element::i64};

// Shape sets for the four STFT inputs. Each entry lists, per input, a
// (possibly dynamic) model shape followed by the concrete shapes to run with;
// an empty model shape means the concrete shape is used as-is.
const std::vector<std::vector<InputShape>> input_shapes = {
    // Fully static, single batch.
    {
        {{}, {{1, 128}}},  // input 0
        {{}, {{8}}},       // input 1
        {{}, {{}}},        // input 2
        {{}, {{}}}         // input 3
    },
    // Fully static, two batches.
    {
        {{}, {{2, 226}}},  // input 0
        {{}, {{16}}},      // input 1
        {{}, {{}}},        // input 2
        {{}, {{}}}         // input 3
    },
    // Dynamic dimensions in input 0 only.
    {
        {{-1, -1}, {{1, 128}, {2, 226}}},  // input 0
        {{}, {{8}}},                       // input 1
        {{}, {{}}},                        // input 2
        {{}, {{}}}                         // input 3
    },
    // Dynamic dimensions in inputs 0 and 1.
    {
        {{-1, -1}, {{1, 128}, {2, 226}}},  // input 0
        {{-1}, {{8}, {16}}},               // input 1
        {{}, {{}}},                        // input 2
        {{}, {{}}}                         // input 3
    },
    // Bounded (ranged) dynamic dimensions in inputs 0 and 1.
    {
        {{{2, 4}, {1, 300}}, {{2, 226}, {3, 128}}},  // input 0
        {{{3, 16}}, {{4}, {16}}},                    // input 1
        {{}, {{}}},                                  // input 2
        {{}, {{}}}                                   // input 3
    }
};

// Frame sizes and frame steps to combine with every shape set.
const std::vector<int64_t> frame_size{16, 24};
const std::vector<int64_t> step_size{2, 3, 4};

// Exercise both values of the transpose_frames attribute.
const std::vector<bool> transpose_frames{false, true};

// Kinds of input layer to test with.
// NOTE(review): presumably this selects whether the frame-size / frame-step
// inputs are fed as Constants or as Parameters — confirm in STFTLayerTest.
//
// Declared const like the other parameter tables in this file: a non-const
// namespace-scope variable has external linkage, so a same-named global in
// another test translation unit would clash at link time, and the table could
// be modified by accident.
const std::vector<utils::InputLayerType> in_types = {utils::InputLayerType::CONSTANT,
                                                     utils::InputLayerType::PARAMETER};

// Full cartesian product of the parameter tables above, run on the CPU device.
// The shape sets include both static and dynamic cases.
const auto stftCpuTestParams = ::testing::Combine(::testing::ValuesIn(input_shapes),
                                                  ::testing::ValuesIn(frame_size),
                                                  ::testing::ValuesIn(step_size),
                                                  ::testing::ValuesIn(transpose_frames),
                                                  ::testing::ValuesIn(data_type),
                                                  ::testing::ValuesIn(step_size_type),
                                                  ::testing::ValuesIn(in_types),
                                                  ::testing::Values(ov::test::utils::DEVICE_CPU));

INSTANTIATE_TEST_SUITE_P(smoke_STFT_static, STFTLayerTest, stftCpuTestParams, STFTLayerTest::getTestCaseName);
} // namespace test
} // namespace ov
Loading

0 comments on commit 2d3289e

Please sign in to comment.