Skip to content

Commit

Permalink
HFFT family implemented for xarray (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
egpbos committed Oct 24, 2017
1 parent c778fbb commit 67fd6de
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 41 deletions.
118 changes: 92 additions & 26 deletions include/xtensor-fftw/basic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#define XTENSOR_FFTW_BASIC_HPP

#include <xtensor/xarray.hpp>
#include "xtensor/xcomplex.hpp"
#include "xtensor/xeval.hpp"
#include <xtl/xcomplex.hpp>
#include <complex>
#include <tuple>
Expand Down Expand Up @@ -191,6 +193,8 @@ namespace xt {

// Callers for fftw_plan_dft, since they have different call signatures and the
// way shape information is extracted from xtensor differs for different dimensionalities.

// REGULAR FFT N-dim
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, fftw_direction, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
-> std::enable_if_t<dimensional::is_n<dim, fftw_123dim>::value && (fftw_direction != 0), typename fftw_t<input_t>::plan> {
Expand All @@ -207,6 +211,7 @@ namespace xt {
flags);
};

// REGULAR FFT 1D
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, fftw_direction, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
-> std::enable_if_t<dimensional::is_1<dim, fftw_123dim>::value && (fftw_direction != 0), typename fftw_t<input_t>::plan> {
Expand All @@ -222,6 +227,7 @@ namespace xt {
flags);
};

// REGULAR FFT 2D
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, fftw_direction, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
-> std::enable_if_t<dimensional::is_2<dim, fftw_123dim>::value && (fftw_direction != 0), typename fftw_t<input_t>::plan> {
Expand All @@ -237,6 +243,7 @@ namespace xt {
flags);
};

// REGULAR FFT 3D
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, fftw_direction, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
-> std::enable_if_t<dimensional::is_3<dim, fftw_123dim>::value && (fftw_direction != 0), typename fftw_t<input_t>::plan> {
Expand All @@ -252,6 +259,7 @@ namespace xt {
flags);
};

// REAL FFT N-dim
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, 0, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
-> std::enable_if_t<dimensional::is_n<dim, fftw_123dim>::value && (fftw_direction == 0), typename fftw_t<input_t>::plan> {
Expand All @@ -267,6 +275,7 @@ namespace xt {
flags);
};

// REAL FFT 1D
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, 0, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
-> std::enable_if_t<dimensional::is_1<dim, fftw_123dim>::value && (fftw_direction == 0), typename fftw_t<input_t>::plan> {
Expand All @@ -281,6 +290,7 @@ namespace xt {
flags);
};

// REAL FFT 2D
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, 0, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
-> std::enable_if_t<dimensional::is_2<dim, fftw_123dim>::value && (fftw_direction == 0), typename fftw_t<input_t>::plan> {
Expand All @@ -295,6 +305,7 @@ namespace xt {
flags);
};

// REAL FFT 3D
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, 0, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
-> std::enable_if_t<dimensional::is_3<dim, fftw_123dim>::value && (fftw_direction == 0), typename fftw_t<input_t>::plan> {
Expand All @@ -310,8 +321,6 @@ namespace xt {
};




////
// General: xarray templates
////
Expand Down Expand Up @@ -378,6 +387,63 @@ namespace xt {
return output / N_dft;
};

template <
typename input_t, typename output_t, std::size_t dim, int fftw_direction, bool fftw_123dim, bool half_plus_one_out, bool half_plus_one_in,
typename fftw_plan_dft_signature<input_t, output_t, dim, fftw_direction, fftw_123dim>::type fftw_plan_dft,
void (&fftw_execute)(typename fftw_t<input_t>::plan), void (&fftw_destroy_plan)(typename fftw_t<input_t>::plan),
typename = std::enable_if_t<
std::is_same< prec_t<input_t>, prec_t<output_t> >::value // input and output precision must be the same
&& std::is_floating_point< prec_t<input_t> >::value // numbers must be float, double or long double
&& (dimensional::is_123<dim, fftw_123dim>::value // dimensionality must match fftw_123dim
|| dimensional::is_n<dim, fftw_123dim>::value)
>
>
inline xt::xarray<output_t> _hfft_(const xt::xarray<input_t, layout_type::row_major> &input) {
auto output_shape = output_shape_from_input(input, half_plus_one_out, half_plus_one_in);
xt::xarray<output_t, layout_type::row_major> output(output_shape);

xt::xarray<input_t, layout_type::row_major> input_conj = xt::conj(input);

auto plan = fftw_plan_dft_caller<dim, fftw_direction, fftw_123dim, input_t, output_t, fftw_plan_dft, half_plus_one_out, half_plus_one_in>(input_conj, output, FFTW_ESTIMATE);
if (plan == nullptr) {
throw std::runtime_error("Plan creation returned nullptr. This usually means FFTW cannot create a plan for the given arguments (e.g. a non-destructive multi-dimensional real FFT is impossible in FFTW).");
}

fftw_execute(plan);
fftw_destroy_plan(plan);
return output;
};

template <
typename input_t, typename output_t, std::size_t dim, int fftw_direction, bool fftw_123dim, bool half_plus_one_out, bool half_plus_one_in,
typename fftw_plan_dft_signature<input_t, output_t, dim, fftw_direction, fftw_123dim>::type fftw_plan_dft,
void (&fftw_execute)(typename fftw_t<input_t>::plan), void (&fftw_destroy_plan)(typename fftw_t<input_t>::plan),
typename = std::enable_if_t<
std::is_same< prec_t<input_t>, prec_t<output_t> >::value // input and output precision must be the same
&& std::is_floating_point< prec_t<input_t> >::value // numbers must be float, double or long double
&& (dimensional::is_123<dim, fftw_123dim>::value // dimensionality must match fftw_123dim
|| dimensional::is_n<dim, fftw_123dim>::value)
>
>
inline xt::xarray<output_t> _ihfft_(const xt::xarray<input_t, layout_type::row_major> &input) {
auto output_shape = output_shape_from_input(input, half_plus_one_out, half_plus_one_in);
xt::xarray<output_t, layout_type::row_major> output(output_shape);

auto plan = fftw_plan_dft_caller<dim, fftw_direction, fftw_123dim, input_t, output_t, fftw_plan_dft, half_plus_one_out, half_plus_one_in>(input, output, FFTW_ESTIMATE);
if (plan == nullptr) {
throw std::runtime_error("Plan creation returned nullptr. This usually means FFTW cannot create a plan for the given arguments (e.g. a non-destructive multi-dimensional real FFT is impossible in FFTW).");
}

fftw_execute(plan);
fftw_destroy_plan(plan);

output = xt::conj(output);

auto dft_dimensions = dft_dimensions_from_output(output, half_plus_one_out);
auto N_dft = static_cast<prec_t<output_t> >(std::accumulate(dft_dimensions.begin(), dft_dimensions.end(), 1, std::multiplies<std::size_t>()));
return output / N_dft;
};


////
// General: xtensor templates
Expand Down Expand Up @@ -682,55 +748,55 @@ namespace xt {
////

inline xt::xarray<float> hfft (const xt::xarray<std::complex<float> > &input) {
return _fft_<std::complex<float>, float, 1, 0, true, false, true, fftwf_plan_dft_c2r_1d, fftwf_execute, fftwf_destroy_plan> (input);
return _hfft_<std::complex<float>, float, 1, 0, true, false, true, fftwf_plan_dft_c2r_1d, fftwf_execute, fftwf_destroy_plan> (input);
}

inline xt::xarray<std::complex<float> > ihfft (const xt::xarray<float> &input) {
return _ifft_<float, std::complex<float>, 1, 0, true, true, false, fftwf_plan_dft_r2c_1d, fftwf_execute, fftwf_destroy_plan> (input);
return _ihfft_<float, std::complex<float>, 1, 0, true, true, false, fftwf_plan_dft_r2c_1d, fftwf_execute, fftwf_destroy_plan> (input);
}

inline xt::xarray<double> hfft (const xt::xarray<std::complex<double> > &input) {
return _fft_<std::complex<double>, double, 1, 0, true, false, true, fftw_plan_dft_c2r_1d, fftw_execute, fftw_destroy_plan> (input);
return _hfft_<std::complex<double>, double, 1, 0, true, false, true, fftw_plan_dft_c2r_1d, fftw_execute, fftw_destroy_plan> (input);
}

inline xt::xarray<std::complex<double> > ihfft (const xt::xarray<double> &input) {
return _ifft_<double, std::complex<double>, 1, 0, true, true, false, fftw_plan_dft_r2c_1d, fftw_execute, fftw_destroy_plan> (input);
return _ihfft_<double, std::complex<double>, 1, 0, true, true, false, fftw_plan_dft_r2c_1d, fftw_execute, fftw_destroy_plan> (input);
}

inline xt::xarray<long double> hfft (const xt::xarray<std::complex<long double> > &input) {
return _fft_<std::complex<long double>, long double, 1, 0, true, false, true, fftwl_plan_dft_c2r_1d, fftwl_execute, fftwl_destroy_plan> (input);
return _hfft_<std::complex<long double>, long double, 1, 0, true, false, true, fftwl_plan_dft_c2r_1d, fftwl_execute, fftwl_destroy_plan> (input);
}

inline xt::xarray<std::complex<long double> > ihfft (const xt::xarray<long double> &input) {
return _ifft_<long double, std::complex<long double>, 1, 0, true, true, false, fftwl_plan_dft_r2c_1d, fftwl_execute, fftwl_destroy_plan> (input);
return _ihfft_<long double, std::complex<long double>, 1, 0, true, true, false, fftwl_plan_dft_r2c_1d, fftwl_execute, fftwl_destroy_plan> (input);
}

////
// Hermitian FFT: 2D
////

inline xt::xarray<float> hfft2 (const xt::xarray<std::complex<float> > &input) {
return _fft_<std::complex<float>, float, 2, 0, true, false, true, fftwf_plan_dft_c2r_2d, fftwf_execute, fftwf_destroy_plan> (input);
return _hfft_<std::complex<float>, float, 2, 0, true, false, true, fftwf_plan_dft_c2r_2d, fftwf_execute, fftwf_destroy_plan> (input);
}

inline xt::xarray<std::complex<float> > ihfft2 (const xt::xarray<float> &input) {
return _ifft_<float, std::complex<float>, 2, 0, true, true, false, fftwf_plan_dft_r2c_2d, fftwf_execute, fftwf_destroy_plan> (input);
return _ihfft_<float, std::complex<float>, 2, 0, true, true, false, fftwf_plan_dft_r2c_2d, fftwf_execute, fftwf_destroy_plan> (input);
}

inline xt::xarray<double> hfft2 (const xt::xarray<std::complex<double> > &input) {
return _fft_<std::complex<double>, double, 2, 0, true, false, true, fftw_plan_dft_c2r_2d, fftw_execute, fftw_destroy_plan> (input);
return _hfft_<std::complex<double>, double, 2, 0, true, false, true, fftw_plan_dft_c2r_2d, fftw_execute, fftw_destroy_plan> (input);
}

inline xt::xarray<std::complex<double> > ihfft2 (const xt::xarray<double> &input) {
return _ifft_<double, std::complex<double>, 2, 0, true, true, false, fftw_plan_dft_r2c_2d, fftw_execute, fftw_destroy_plan> (input);
return _ihfft_<double, std::complex<double>, 2, 0, true, true, false, fftw_plan_dft_r2c_2d, fftw_execute, fftw_destroy_plan> (input);
}

inline xt::xarray<long double> hfft2 (const xt::xarray<std::complex<long double> > &input) {
return _fft_<std::complex<long double>, long double, 2, 0, true, false, true, fftwl_plan_dft_c2r_2d, fftwl_execute, fftwl_destroy_plan> (input);
return _hfft_<std::complex<long double>, long double, 2, 0, true, false, true, fftwl_plan_dft_c2r_2d, fftwl_execute, fftwl_destroy_plan> (input);
}

inline xt::xarray<std::complex<long double> > ihfft2 (const xt::xarray<long double> &input) {
return _ifft_<long double, std::complex<long double>, 2, 0, true, true, false, fftwl_plan_dft_r2c_2d, fftwl_execute, fftwl_destroy_plan> (input);
return _ihfft_<long double, std::complex<long double>, 2, 0, true, true, false, fftwl_plan_dft_r2c_2d, fftwl_execute, fftwl_destroy_plan> (input);
}


Expand All @@ -739,27 +805,27 @@ namespace xt {
////

inline xt::xarray<float> hfft3 (const xt::xarray<std::complex<float> > &input) {
return _fft_<std::complex<float>, float, 3, 0, true, false, true, fftwf_plan_dft_c2r_3d, fftwf_execute, fftwf_destroy_plan> (input);
return _hfft_<std::complex<float>, float, 3, 0, true, false, true, fftwf_plan_dft_c2r_3d, fftwf_execute, fftwf_destroy_plan> (input);
}

inline xt::xarray<std::complex<float> > ihfft3 (const xt::xarray<float> &input) {
return _ifft_<float, std::complex<float>, 3, 0, true, true, false, fftwf_plan_dft_r2c_3d, fftwf_execute, fftwf_destroy_plan> (input);
return _ihfft_<float, std::complex<float>, 3, 0, true, true, false, fftwf_plan_dft_r2c_3d, fftwf_execute, fftwf_destroy_plan> (input);
}

inline xt::xarray<double> hfft3 (const xt::xarray<std::complex<double> > &input) {
return _fft_<std::complex<double>, double, 3, 0, true, false, true, fftw_plan_dft_c2r_3d, fftw_execute, fftw_destroy_plan> (input);
return _hfft_<std::complex<double>, double, 3, 0, true, false, true, fftw_plan_dft_c2r_3d, fftw_execute, fftw_destroy_plan> (input);
}

inline xt::xarray<std::complex<double> > ihfft3 (const xt::xarray<double> &input) {
return _ifft_<double, std::complex<double>, 3, 0, true, true, false, fftw_plan_dft_r2c_3d, fftw_execute, fftw_destroy_plan> (input);
return _ihfft_<double, std::complex<double>, 3, 0, true, true, false, fftw_plan_dft_r2c_3d, fftw_execute, fftw_destroy_plan> (input);
}

inline xt::xarray<long double> hfft3 (const xt::xarray<std::complex<long double> > &input) {
return _fft_<std::complex<long double>, long double, 3, 0, true, false, true, fftwl_plan_dft_c2r_3d, fftwl_execute, fftwl_destroy_plan> (input);
return _hfft_<std::complex<long double>, long double, 3, 0, true, false, true, fftwl_plan_dft_c2r_3d, fftwl_execute, fftwl_destroy_plan> (input);
}

inline xt::xarray<std::complex<long double> > ihfft3 (const xt::xarray<long double> &input) {
return _ifft_<long double, std::complex<long double>, 3, 0, true, true, false, fftwl_plan_dft_r2c_3d, fftwl_execute, fftwl_destroy_plan> (input);
return _ihfft_<long double, std::complex<long double>, 3, 0, true, true, false, fftwl_plan_dft_r2c_3d, fftwl_execute, fftwl_destroy_plan> (input);
}


Expand All @@ -769,32 +835,32 @@ namespace xt {

template <std::size_t dim>
inline xt::xarray<float> hfftn (const xt::xarray<std::complex<float> > &input) {
return _fft_<std::complex<float>, float, dim, 0, false, false, true, fftwf_plan_dft_c2r, fftwf_execute, fftwf_destroy_plan> (input);
return _hfft_<std::complex<float>, float, dim, 0, false, false, true, fftwf_plan_dft_c2r, fftwf_execute, fftwf_destroy_plan> (input);
}

template <std::size_t dim>
inline xt::xarray<std::complex<float> > ihfftn (const xt::xarray<float> &input) {
return _ifft_<float, std::complex<float>, dim, 0, false, true, false, fftwf_plan_dft_r2c, fftwf_execute, fftwf_destroy_plan> (input);
return _ihfft_<float, std::complex<float>, dim, 0, false, true, false, fftwf_plan_dft_r2c, fftwf_execute, fftwf_destroy_plan> (input);
}

template <std::size_t dim>
inline xt::xarray<double> hfftn (const xt::xarray<std::complex<double> > &input) {
return _fft_<std::complex<double>, double, dim, 0, false, false, true, fftw_plan_dft_c2r, fftw_execute, fftw_destroy_plan> (input);
return _hfft_<std::complex<double>, double, dim, 0, false, false, true, fftw_plan_dft_c2r, fftw_execute, fftw_destroy_plan> (input);
}

template <std::size_t dim>
inline xt::xarray<std::complex<double> > ihfftn (const xt::xarray<double> &input) {
return _ifft_<double, std::complex<double>, dim, 0, false, true, false, fftw_plan_dft_r2c, fftw_execute, fftw_destroy_plan> (input);
return _ihfft_<double, std::complex<double>, dim, 0, false, true, false, fftw_plan_dft_r2c, fftw_execute, fftw_destroy_plan> (input);
}

template <std::size_t dim>
inline xt::xarray<long double> hfftn (const xt::xarray<std::complex<long double> > &input) {
return _fft_<std::complex<long double>, long double, dim, 0, false, false, true, fftwl_plan_dft_c2r, fftwl_execute, fftwl_destroy_plan> (input);
return _hfft_<std::complex<long double>, long double, dim, 0, false, false, true, fftwl_plan_dft_c2r, fftwl_execute, fftwl_destroy_plan> (input);
}

template <std::size_t dim>
inline xt::xarray<std::complex<long double> > ihfftn (const xt::xarray<long double> &input) {
return _ifft_<long double, std::complex<long double>, dim, 0, false, true, false, fftwl_plan_dft_r2c, fftwl_execute, fftwl_destroy_plan> (input);
return _ihfft_<long double, std::complex<long double>, dim, 0, false, true, false, fftwl_plan_dft_r2c, fftwl_execute, fftwl_destroy_plan> (input);
}

}
Expand Down
Loading

0 comments on commit 67fd6de

Please sign in to comment.