diff --git a/include/xtensor-fftw/basic.hpp b/include/xtensor-fftw/basic.hpp index 8ba134e..49326fc 100644 --- a/include/xtensor-fftw/basic.hpp +++ b/include/xtensor-fftw/basic.hpp @@ -17,6 +17,8 @@ #define XTENSOR_FFTW_BASIC_HPP #include +#include "xtensor/xcomplex.hpp" +#include "xtensor/xeval.hpp" #include #include #include @@ -191,6 +193,8 @@ namespace xt { // Callers for fftw_plan_dft, since they have different call signatures and the // way shape information is extracted from xtensor differs for different dimensionalities. + + // REGULAR FFT N-dim template ::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in> inline auto fftw_plan_dft_caller(const xt::xarray &input, xt::xarray &output, unsigned int flags) -> std::enable_if_t::value && (fftw_direction != 0), typename fftw_t::plan> { @@ -207,6 +211,7 @@ namespace xt { flags); }; + // REGULAR FFT 1D template ::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in> inline auto fftw_plan_dft_caller(const xt::xarray &input, xt::xarray &output, unsigned int flags) -> std::enable_if_t::value && (fftw_direction != 0), typename fftw_t::plan> { @@ -222,6 +227,7 @@ namespace xt { flags); }; + // REGULAR FFT 2D template ::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in> inline auto fftw_plan_dft_caller(const xt::xarray &input, xt::xarray &output, unsigned int flags) -> std::enable_if_t::value && (fftw_direction != 0), typename fftw_t::plan> { @@ -237,6 +243,7 @@ namespace xt { flags); }; + // REGULAR FFT 3D template ::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in> inline auto fftw_plan_dft_caller(const xt::xarray &input, xt::xarray &output, unsigned int flags) -> std::enable_if_t::value && (fftw_direction != 0), typename fftw_t::plan> { @@ -252,6 +259,7 @@ namespace xt { flags); }; + // REAL FFT N-dim template ::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in> inline auto fftw_plan_dft_caller(const xt::xarray &input, xt::xarray &output, unsigned int flags) -> std::enable_if_t::value && (fftw_direction == 0), typename fftw_t::plan> { @@ -267,6 +275,7 @@ namespace xt { flags); }; + // REAL FFT 1D template ::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in> inline auto fftw_plan_dft_caller(const xt::xarray &input, xt::xarray &output, unsigned int flags) -> std::enable_if_t::value && (fftw_direction == 0), typename fftw_t::plan> { @@ -281,6 +290,7 @@ namespace xt { flags); }; + // REAL FFT 2D template ::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in> inline auto fftw_plan_dft_caller(const xt::xarray &input, xt::xarray &output, unsigned int flags) -> std::enable_if_t::value && (fftw_direction == 0), typename fftw_t::plan> { @@ -295,6 +305,7 @@ namespace xt { flags); }; + // REAL FFT 3D template ::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in> inline auto fftw_plan_dft_caller(const xt::xarray &input, xt::xarray &output, unsigned int flags) -> std::enable_if_t::value && (fftw_direction == 0), typename fftw_t::plan> { @@ -310,8 +321,6 @@ namespace xt { }; - - //// // General: xarray templates //// @@ -378,6 +387,63 @@ namespace xt { return output / N_dft; }; + template < + typename input_t, typename output_t, std::size_t dim, int fftw_direction, bool fftw_123dim, bool half_plus_one_out, bool half_plus_one_in, + typename fftw_plan_dft_signature::type fftw_plan_dft, + void (&fftw_execute)(typename fftw_t::plan), void (&fftw_destroy_plan)(typename fftw_t::plan), + typename = std::enable_if_t< + std::is_same< prec_t, prec_t >::value // input and output precision must be the same + && std::is_floating_point< prec_t >::value // numbers must be float, double or long double + && (dimensional::is_123::value // dimensionality must match fftw_123dim + || dimensional::is_n::value) + > + > + inline xt::xarray _hfft_(const xt::xarray &input) { + auto output_shape = output_shape_from_input(input, half_plus_one_out, half_plus_one_in); + xt::xarray output(output_shape); + + xt::xarray input_conj = xt::conj(input); + + auto plan = fftw_plan_dft_caller(input_conj, output, FFTW_ESTIMATE); + if (plan == nullptr) { + throw std::runtime_error("Plan creation returned nullptr. This usually means FFTW cannot create a plan for the given arguments (e.g. a non-destructive multi-dimensional real FFT is impossible in FFTW)."); + } + + fftw_execute(plan); + fftw_destroy_plan(plan); + return output; + }; + + template < + typename input_t, typename output_t, std::size_t dim, int fftw_direction, bool fftw_123dim, bool half_plus_one_out, bool half_plus_one_in, + typename fftw_plan_dft_signature::type fftw_plan_dft, + void (&fftw_execute)(typename fftw_t::plan), void (&fftw_destroy_plan)(typename fftw_t::plan), + typename = std::enable_if_t< + std::is_same< prec_t, prec_t >::value // input and output precision must be the same + && std::is_floating_point< prec_t >::value // numbers must be float, double or long double + && (dimensional::is_123::value // dimensionality must match fftw_123dim + || dimensional::is_n::value) + > + > + inline xt::xarray _ihfft_(const xt::xarray &input) { + auto output_shape = output_shape_from_input(input, half_plus_one_out, half_plus_one_in); + xt::xarray output(output_shape); + + auto plan = fftw_plan_dft_caller(input, output, FFTW_ESTIMATE); + if (plan == nullptr) { + throw std::runtime_error("Plan creation returned nullptr. This usually means FFTW cannot create a plan for the given arguments (e.g. a non-destructive multi-dimensional real FFT is impossible in FFTW)."); + } + + fftw_execute(plan); + fftw_destroy_plan(plan); + + output = xt::conj(output); + + auto dft_dimensions = dft_dimensions_from_output(output, half_plus_one_out); + auto N_dft = static_cast >(std::accumulate(dft_dimensions.begin(), dft_dimensions.end(), 1, std::multiplies())); + return output / N_dft; + }; + //// // General: xtensor templates @@ -682,27 +748,27 @@ namespace xt { //// inline xt::xarray hfft (const xt::xarray > &input) { - return _fft_, float, 1, 0, true, false, true, fftwf_plan_dft_c2r_1d, fftwf_execute, fftwf_destroy_plan> (input); + return _hfft_, float, 1, 0, true, false, true, fftwf_plan_dft_c2r_1d, fftwf_execute, fftwf_destroy_plan> (input); } inline xt::xarray > ihfft (const xt::xarray &input) { - return _ifft_, 1, 0, true, true, false, fftwf_plan_dft_r2c_1d, fftwf_execute, fftwf_destroy_plan> (input); + return _ihfft_, 1, 0, true, true, false, fftwf_plan_dft_r2c_1d, fftwf_execute, fftwf_destroy_plan> (input); } inline xt::xarray hfft (const xt::xarray > &input) { - return _fft_, double, 1, 0, true, false, true, fftw_plan_dft_c2r_1d, fftw_execute, fftw_destroy_plan> (input); + return _hfft_, double, 1, 0, true, false, true, fftw_plan_dft_c2r_1d, fftw_execute, fftw_destroy_plan> (input); } inline xt::xarray > ihfft (const xt::xarray &input) { - return _ifft_, 1, 0, true, true, false, fftw_plan_dft_r2c_1d, fftw_execute, fftw_destroy_plan> (input); + return _ihfft_, 1, 0, true, true, false, fftw_plan_dft_r2c_1d, fftw_execute, fftw_destroy_plan> (input); } inline xt::xarray hfft (const xt::xarray > &input) { - return _fft_, long double, 1, 0, true, false, true, fftwl_plan_dft_c2r_1d, fftwl_execute, fftwl_destroy_plan> (input); + return _hfft_, long double, 1, 0, true, false, true, fftwl_plan_dft_c2r_1d, fftwl_execute, fftwl_destroy_plan> (input); } inline xt::xarray > ihfft (const xt::xarray &input) { - return _ifft_, 1, 0, true, true, false, fftwl_plan_dft_r2c_1d, fftwl_execute, fftwl_destroy_plan> (input); + return _ihfft_, 1, 0, true, true, false, fftwl_plan_dft_r2c_1d, fftwl_execute, fftwl_destroy_plan> (input); } //// @@ -710,27 +776,27 @@ namespace xt { //// inline xt::xarray hfft2 (const xt::xarray > &input) { - return _fft_, float, 2, 0, true, false, true, fftwf_plan_dft_c2r_2d, fftwf_execute, fftwf_destroy_plan> (input); + return _hfft_, float, 2, 0, true, false, true, fftwf_plan_dft_c2r_2d, fftwf_execute, fftwf_destroy_plan> (input); } inline xt::xarray > ihfft2 (const xt::xarray &input) { - return _ifft_, 2, 0, true, true, false, fftwf_plan_dft_r2c_2d, fftwf_execute, fftwf_destroy_plan> (input); + return _ihfft_, 2, 0, true, true, false, fftwf_plan_dft_r2c_2d, fftwf_execute, fftwf_destroy_plan> (input); } inline xt::xarray hfft2 (const xt::xarray > &input) { - return _fft_, double, 2, 0, true, false, true, fftw_plan_dft_c2r_2d, fftw_execute, fftw_destroy_plan> (input); + return _hfft_, double, 2, 0, true, false, true, fftw_plan_dft_c2r_2d, fftw_execute, fftw_destroy_plan> (input); } inline xt::xarray > ihfft2 (const xt::xarray &input) { - return _ifft_, 2, 0, true, true, false, fftw_plan_dft_r2c_2d, fftw_execute, fftw_destroy_plan> (input); + return _ihfft_, 2, 0, true, true, false, fftw_plan_dft_r2c_2d, fftw_execute, fftw_destroy_plan> (input); } inline xt::xarray hfft2 (const xt::xarray > &input) { - return _fft_, long double, 2, 0, true, false, true, fftwl_plan_dft_c2r_2d, fftwl_execute, fftwl_destroy_plan> (input); + return _hfft_, long double, 2, 0, true, false, true, fftwl_plan_dft_c2r_2d, fftwl_execute, fftwl_destroy_plan> (input); } inline xt::xarray > ihfft2 (const xt::xarray &input) { - return _ifft_, 2, 0, true, true, false, fftwl_plan_dft_r2c_2d, fftwl_execute, fftwl_destroy_plan> (input); + return _ihfft_, 2, 0, true, true, false, fftwl_plan_dft_r2c_2d, fftwl_execute, fftwl_destroy_plan> (input); } @@ -739,27 +805,27 @@ namespace xt { //// inline xt::xarray hfft3 (const xt::xarray > &input) { - return _fft_, float, 3, 0, true, false, true, fftwf_plan_dft_c2r_3d, fftwf_execute, fftwf_destroy_plan> (input); + return _hfft_, float, 3, 0, true, false, true, fftwf_plan_dft_c2r_3d, fftwf_execute, fftwf_destroy_plan> (input); } inline xt::xarray > ihfft3 (const xt::xarray &input) { - return _ifft_, 3, 0, true, true, false, fftwf_plan_dft_r2c_3d, fftwf_execute, fftwf_destroy_plan> (input); + return _ihfft_, 3, 0, true, true, false, fftwf_plan_dft_r2c_3d, fftwf_execute, fftwf_destroy_plan> (input); } inline xt::xarray hfft3 (const xt::xarray > &input) { - return _fft_, double, 3, 0, true, false, true, fftw_plan_dft_c2r_3d, fftw_execute, fftw_destroy_plan> (input); + return _hfft_, double, 3, 0, true, false, true, fftw_plan_dft_c2r_3d, fftw_execute, fftw_destroy_plan> (input); } inline xt::xarray > ihfft3 (const xt::xarray &input) { - return _ifft_, 3, 0, true, true, false, fftw_plan_dft_r2c_3d, fftw_execute, fftw_destroy_plan> (input); + return _ihfft_, 3, 0, true, true, false, fftw_plan_dft_r2c_3d, fftw_execute, fftw_destroy_plan> (input); } inline xt::xarray hfft3 (const xt::xarray > &input) { - return _fft_, long double, 3, 0, true, false, true, fftwl_plan_dft_c2r_3d, fftwl_execute, fftwl_destroy_plan> (input); + return _hfft_, long double, 3, 0, true, false, true, fftwl_plan_dft_c2r_3d, fftwl_execute, fftwl_destroy_plan> (input); } inline xt::xarray > ihfft3 (const xt::xarray &input) { - return _ifft_, 3, 0, true, true, false, fftwl_plan_dft_r2c_3d, fftwl_execute, fftwl_destroy_plan> (input); + return _ihfft_, 3, 0, true, true, false, fftwl_plan_dft_r2c_3d, fftwl_execute, fftwl_destroy_plan> (input); } @@ -769,32 +835,32 @@ namespace xt { template inline xt::xarray hfftn (const xt::xarray > &input) { - return _fft_, float, dim, 0, false, false, true, fftwf_plan_dft_c2r, fftwf_execute, fftwf_destroy_plan> (input); + return _hfft_, float, dim, 0, false, false, true, fftwf_plan_dft_c2r, fftwf_execute, fftwf_destroy_plan> (input); } template inline xt::xarray > ihfftn (const xt::xarray &input) { - return _ifft_, dim, 0, false, true, false, fftwf_plan_dft_r2c, fftwf_execute, fftwf_destroy_plan> (input); + return _ihfft_, dim, 0, false, true, false, fftwf_plan_dft_r2c, fftwf_execute, fftwf_destroy_plan> (input); } template inline xt::xarray hfftn (const xt::xarray > &input) { - return _fft_, double, dim, 0, false, false, true, fftw_plan_dft_c2r, fftw_execute, fftw_destroy_plan> (input); + return _hfft_, double, dim, 0, false, false, true, fftw_plan_dft_c2r, fftw_execute, fftw_destroy_plan> (input); } template inline xt::xarray > ihfftn (const xt::xarray &input) { - return _ifft_, dim, 0, false, true, false, fftw_plan_dft_r2c, fftw_execute, fftw_destroy_plan> (input); + return _ihfft_, dim, 0, false, true, false, fftw_plan_dft_r2c, fftw_execute, fftw_destroy_plan> (input); } template inline xt::xarray hfftn (const xt::xarray > &input) { - return _fft_, long double, dim, 0, false, false, true, fftwl_plan_dft_c2r, fftwl_execute, fftwl_destroy_plan> (input); + return _hfft_, long double, dim, 0, false, false, true, fftwl_plan_dft_c2r, fftwl_execute, fftwl_destroy_plan> (input); } template inline xt::xarray > ihfftn (const xt::xarray &input) { - return _ifft_, dim, 0, false, true, false, fftwl_plan_dft_r2c, fftwl_execute, fftwl_destroy_plan> (input); + return _ihfft_, dim, 0, false, true, false, fftwl_plan_dft_r2c, fftwl_execute, fftwl_destroy_plan> (input); } } diff --git a/test/basic_interface.cpp b/test/basic_interface.cpp index f93d40a..62f5507 100644 --- a/test/basic_interface.cpp +++ b/test/basic_interface.cpp @@ -269,46 +269,41 @@ TYPED_TEST(TransformAndInvert, realFFT_4D_xtensor) { TYPED_TEST(TransformAndInvert, hermFFT_1D_xarray) { xt::xarray, xt::layout_type::row_major> a = generate_hermitian_data(data_size); - xt::xarray, xt::layout_type::row_major> a_conj = xt::conj(a); - auto a_fourier = xt::fftw::hfft(a_conj); + auto a_fourier = xt::fftw::hfft(a); std::cout << "fourier transform of input before ifft (which is destructive!): " << a_fourier << std::endl; - auto should_be_a = xt::conj(xt::fftw::ihfft(a_fourier)); + auto should_be_a = xt::fftw::ihfft(a_fourier); assert_results_complex(a, a_fourier, should_be_a); } TYPED_TEST(TransformAndInvert, hermFFT_2D_xarray) { xt::xarray, xt::layout_type::row_major> a = generate_hermitian_data(data_size); - xt::xarray, xt::layout_type::row_major> a_conj = xt::conj(a); - auto a_fourier = xt::fftw::hfft2(a_conj); + auto a_fourier = xt::fftw::hfft2(a); std::cout << "fourier transform of input before ifft (which is destructive!): " << a_fourier << std::endl; - auto should_be_a = xt::conj(xt::fftw::ihfft2(a_fourier)); + auto should_be_a = xt::fftw::ihfft2(a_fourier); assert_results_complex(a, a_fourier, should_be_a); } TYPED_TEST(TransformAndInvert, hermFFT_3D_xarray) { xt::xarray, xt::layout_type::row_major> a = generate_hermitian_data(data_size); - xt::xarray, xt::layout_type::row_major> a_conj = xt::conj(a); - auto a_fourier = xt::fftw::hfft3(a_conj); + auto a_fourier = xt::fftw::hfft3(a); std::cout << "fourier transform of input before ifft (which is destructive!): " << a_fourier << std::endl; - auto should_be_a = xt::conj(xt::fftw::ihfft3(a_fourier)); + auto should_be_a = xt::fftw::ihfft3(a_fourier); assert_results_complex(a, a_fourier, should_be_a); } TYPED_TEST(TransformAndInvert, hermFFT_nD_n_equals_4_xarray) { xt::xarray, xt::layout_type::row_major> a = generate_hermitian_data, xt::fftw::ihfftn<4>>(data_size); - xt::xarray, xt::layout_type::row_major> a_conj = xt::conj(a); - auto a_fourier = xt::fftw::hfftn<4>(a_conj); + auto a_fourier = xt::fftw::hfftn<4>(a); std::cout << "fourier transform of input before ifft (which is destructive!): " << a_fourier << std::endl; - auto should_be_a = xt::conj(xt::fftw::ihfftn<4>(a_fourier)); + auto should_be_a = xt::fftw::ihfftn<4>(a_fourier); assert_results_complex(a, a_fourier, should_be_a); } TYPED_TEST(TransformAndInvert, hermFFT_nD_n_equals_1_xarray) { xt::xarray, xt::layout_type::row_major> a = generate_hermitian_data(data_size); - xt::xarray, xt::layout_type::row_major> a_conj = xt::conj(a); - auto a_fourier = xt::fftw::hfftn<1>(a_conj); + auto a_fourier = xt::fftw::hfftn<1>(a); std::cout << "fourier transform of input before ifft (which is destructive!): " << a_fourier << std::endl; - auto should_be_a = xt::conj(xt::fftw::ihfftn<1>(a_fourier)); + auto should_be_a = xt::fftw::ihfftn<1>(a_fourier); assert_results_complex(a, a_fourier, should_be_a); }