diff --git a/src/frontends/pytorch/src/op/stft.cpp b/src/frontends/pytorch/src/op/stft.cpp index b7e4858c2f8fcc..d1fe4f9f15828b 100644 --- a/src/frontends/pytorch/src/op/stft.cpp +++ b/src/frontends/pytorch/src/op/stft.cpp @@ -10,6 +10,7 @@ #include "openvino/op/convert_like.hpp" #include "openvino/op/divide.hpp" #include "openvino/op/shape_of.hpp" +#include "openvino/op/sqrt.hpp" #include "openvino/op/unsqueeze.hpp" #include "utils.hpp" @@ -66,8 +67,6 @@ OutputVector translate_stft(const NodeContext& context) { if (!context.input_is_none(5)) { normalized = context.const_input(5); } - PYTORCH_OP_CONVERSION_CHECK(!normalized, - "aten::stft conversion is currently supported with normalized=False only."); bool onesided = true; if (!context.input_is_none(6)) { @@ -85,7 +84,15 @@ OutputVector translate_stft(const NodeContext& context) { // Perform STFT constexpr bool transpose_frames = true; auto stft = context.mark_node(std::make_shared(input, window, n_fft, hop_length, transpose_frames)); - return {stft}; + + if (normalized) { + const auto nfft_convert = context.mark_node(std::make_shared(n_fft, stft)); + const auto divisor = context.mark_node(std::make_shared(nfft_convert)); + const auto norm_stft = context.mark_node(std::make_shared(stft, divisor)); + return {norm_stft}; + } else { + return {stft}; + } }; } // namespace op } // namespace pytorch diff --git a/tests/layer_tests/pytorch_tests/test_stft.py b/tests/layer_tests/pytorch_tests/test_stft.py index 832a624da65626..d0ff347e58602f 100644 --- a/tests/layer_tests/pytorch_tests/test_stft.py +++ b/tests/layer_tests/pytorch_tests/test_stft.py @@ -24,16 +24,17 @@ def _prepare_input(self, win_length, signal_shape, rand_data=False, out_dtype="f return (signal, window.astype(out_dtype)) - def create_model(self, n_fft, hop_length, win_length): + def create_model(self, n_fft, hop_length, win_length, normalized): import torch class aten_stft(torch.nn.Module): - def __init__(self, n_fft, hop_length, win_length): + def __init__(self, n_fft, hop_length, win_length, normalized): super(aten_stft, self).__init__() self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length + self.normalized = normalized def forward(self, x, window): return torch.stft( @@ -44,14 +45,14 @@ def forward(self, x, window): window=window, center=False, pad_mode="reflect", - normalized=False, + normalized=self.normalized, onesided=True, return_complex=False, ) ref_net = None - return aten_stft(n_fft, hop_length, win_length), ref_net, "aten::stft" + return aten_stft(n_fft, hop_length, win_length, normalized), ref_net, "aten::stft" @pytest.mark.nightly @pytest.mark.precommit @@ -64,10 +65,11 @@ def forward(self, x, window): [24, 32, 20], [128, 128, 128], ]) - def test_stft(self, n_fft, hop_length, window_size, signal_shape, ie_device, precision, ir_version, trace_model): + @pytest.mark.parametrize(("normalized"), [True, False]) + def test_stft(self, n_fft, hop_length, window_size, signal_shape, normalized, ie_device, precision, ir_version, trace_model): if ie_device == "GPU": pytest.xfail(reason="STFT op is not supported on GPU yet") - self._test(*self.create_model(n_fft, hop_length, window_size), ie_device, precision, + self._test(*self.create_model(n_fft, hop_length, window_size, normalized), ie_device, precision, ir_version, kwargs_to_prepare_input={"win_length": window_size, "signal_shape": signal_shape}, trace_model=trace_model) @@ -125,8 +127,8 @@ def forward(self, x): [16, None, 16, False, "reflect", False, True, False], # hop_length None [16, None, None, False, "reflect", False, True, False], # hop & win length None [16, 4, None, False, "reflect", False, True, False], # win_length None - # Unsupported cases: [16, 4, 16, False, "reflect", True, True, False], # normalized True + # Unsupported cases: [16, 4, 16, False, "reflect", False, False, False], # onesided False [16, 4, 16, False, "reflect", False, True, True], # reutrn_complex True ]) @@ -138,10 +140,6 @@ def test_stft_not_supported_attrs(self, n_fft, hop_length, win_length, center, p pytest.xfail( reason="torch stft uses list() for `center` subgrpah before aten::stft, that leads to error: No conversion rule found for operations: aten::list") - if normalized is True: - pytest.xfail( - reason="aten::stft conversion is currently supported with normalized=False only") - if onesided is False: pytest.xfail( reason="aten::stft conversion is currently supported with onesided=True only")