From a1c63535a07a074e196436635fb3e354bf49670a Mon Sep 17 00:00:00 2001 From: "Vandenplas, Jeremie" Date: Sun, 19 Jan 2020 20:11:49 +0100 Subject: [PATCH] modification to have the same behaviour as Fortran sum --- src/stdlib_experimental_stat.f90 | 26 +++++-- src/stdlib_experimental_stat.fypp.f90 | 12 ++- src/stdlib_experimental_stat_mean.f90 | 89 +++++++++++----------- src/stdlib_experimental_stat_mean.fypp.f90 | 31 ++++---- src/tests/stat/test_mean.f90 | 10 ++- 5 files changed, 98 insertions(+), 70 deletions(-) diff --git a/src/stdlib_experimental_stat.f90 b/src/stdlib_experimental_stat.f90 index 1c11634e3..6a399313d 100644 --- a/src/stdlib_experimental_stat.f90 +++ b/src/stdlib_experimental_stat.f90 @@ -22,20 +22,34 @@ module function mean_1_qp_qp(mat) result(res) real(qp) ::res end function mean_1_qp_qp + module function mean_2_all_sp_sp(mat) result(res) + real(sp), intent(in) :: mat(:,:) + real(sp) ::res + end function mean_2_all_sp_sp + module function mean_2_all_dp_dp(mat) result(res) + real(dp), intent(in) :: mat(:,:) + real(dp) ::res + end function mean_2_all_dp_dp + module function mean_2_all_qp_qp(mat) result(res) + real(qp), intent(in) :: mat(:,:) + real(qp) ::res + end function mean_2_all_qp_qp + + module function mean_2_sp_sp(mat, dim) result(res) real(sp), intent(in) :: mat(:,:) - integer, intent(in), optional :: dim - real(sp), allocatable ::res(:) + integer, intent(in) :: dim + real(sp) :: res(size(mat)/size(mat, dim)) end function mean_2_sp_sp module function mean_2_dp_dp(mat, dim) result(res) real(dp), intent(in) :: mat(:,:) - integer, intent(in), optional :: dim - real(dp), allocatable ::res(:) + integer, intent(in) :: dim + real(dp) :: res(size(mat)/size(mat, dim)) end function mean_2_dp_dp module function mean_2_qp_qp(mat, dim) result(res) real(qp), intent(in) :: mat(:,:) - integer, intent(in), optional :: dim - real(qp), allocatable ::res(:) + integer, intent(in) :: dim + real(qp) :: res(size(mat)/size(mat, dim)) end function mean_2_qp_qp end interface diff --git a/src/stdlib_experimental_stat.fypp.f90 b/src/stdlib_experimental_stat.fypp.f90 index ae964f45b..2db139b03 100644 --- a/src/stdlib_experimental_stat.fypp.f90 +++ b/src/stdlib_experimental_stat.fypp.f90 @@ -20,11 +20,19 @@ module function mean_1_${k1}$_${k1}$(mat) result(res) end function mean_1_${k1}$_${k1}$ #:endfor +#:for i1, k1, t1 in ikt + module function mean_2_all_${k1}$_${k1}$(mat) result(res) + ${t1}$, intent(in) :: mat(:,:) + ${t1}$ ::res + end function mean_2_all_${k1}$_${k1}$ +#:endfor + + #:for i1, k1, t1 in ikt module function mean_2_${k1}$_${k1}$(mat, dim) result(res) ${t1}$, intent(in) :: mat(:,:) - integer, intent(in), optional :: dim - ${t1}$, allocatable ::res(:) + integer, intent(in) :: dim + ${t1}$ :: res(size(mat)/size(mat, dim)) end function mean_2_${k1}$_${k1}$ #:endfor end interface diff --git a/src/stdlib_experimental_stat_mean.f90 b/src/stdlib_experimental_stat_mean.f90 index 56527d514..3c3548659 100644 --- a/src/stdlib_experimental_stat_mean.f90 +++ b/src/stdlib_experimental_stat_mean.f90 @@ -29,78 +29,79 @@ module function mean_1_qp_qp(mat) result(res) end function mean_1_qp_qp -module function mean_2_sp_sp(mat, dim) result(res) +module function mean_2_all_sp_sp(mat) result(res) real(sp), intent(in) :: mat(:,:) - integer, intent(in), optional :: dim - real(sp), allocatable ::res(:) + real(sp) ::res - integer :: i - integer :: dim_ + res = sum(mat) / real(size(mat), sp) + +end function mean_2_all_sp_sp +module function mean_2_all_dp_dp(mat) result(res) + real(dp), intent(in) :: mat(:,:) + real(dp) ::res + + res = sum(mat) / real(size(mat), dp) - dim_ = optval(dim, 1) +end function mean_2_all_dp_dp +module function mean_2_all_qp_qp(mat) result(res) + real(qp), intent(in) :: mat(:,:) + real(qp) ::res - if (dim_ < 0 .or. dim_ > 2 ) call error_stop("ERROR (mean): invalid argument (dim) ") + res = sum(mat) / real(size(mat), qp) - allocate(res(size(mat, dim_))) +end function mean_2_all_qp_qp - if (dim_ == 1) then - do i=1, size(mat, dim_) - res(i) = mean_1_sp_sp(mat(i,:)) - end do - else if (dim_ == 2) then - do i=1, size(mat, dim_) +module function mean_2_sp_sp(mat, dim) result(res) + real(sp), intent(in) :: mat(:,:) + integer, intent(in) :: dim + real(sp) :: res(size(mat)/size(mat, dim)) + + integer :: i + + if (dim == 1) then + do i=1, size(mat)/size(mat, dim) res(i) = mean_1_sp_sp(mat(:,i)) end do + else if (dim == 2) then + do i=1, size(mat)/size(mat, dim) + res(i) = mean_1_sp_sp(mat(i,:)) + end do end if end function mean_2_sp_sp module function mean_2_dp_dp(mat, dim) result(res) real(dp), intent(in) :: mat(:,:) - integer, intent(in), optional :: dim - real(dp), allocatable ::res(:) + integer, intent(in) :: dim + real(dp) :: res(size(mat)/size(mat, dim)) integer :: i - integer :: dim_ - - dim_ = optval(dim, 1) - - if (dim_ < 0 .or. dim_ > 2 ) call error_stop("ERROR (mean): invalid argument (dim) ") - allocate(res(size(mat, dim_))) - - if (dim_ == 1) then - do i=1, size(mat, dim_) - res(i) = mean_1_dp_dp(mat(i,:)) - end do - else if (dim_ == 2) then - do i=1, size(mat, dim_) + if (dim == 1) then + do i=1, size(mat)/size(mat, dim) res(i) = mean_1_dp_dp(mat(:,i)) end do + else if (dim == 2) then + do i=1, size(mat)/size(mat, dim) + res(i) = mean_1_dp_dp(mat(i,:)) + end do end if end function mean_2_dp_dp module function mean_2_qp_qp(mat, dim) result(res) real(qp), intent(in) :: mat(:,:) - integer, intent(in), optional :: dim - real(qp), allocatable ::res(:) + integer, intent(in) :: dim + real(qp) :: res(size(mat)/size(mat, dim)) integer :: i - integer :: dim_ - - dim_ = optval(dim, 1) - - if (dim_ < 0 .or. dim_ > 2 ) call error_stop("ERROR (mean): invalid argument (dim) ") - - allocate(res(size(mat, dim_))) - if (dim_ == 1) then - do i=1, size(mat, dim_) - res(i) = mean_1_qp_qp(mat(i,:)) - end do - else if (dim_ == 2) then - do i=1, size(mat, dim_) + if (dim == 1) then + do i=1, size(mat)/size(mat, dim) res(i) = mean_1_qp_qp(mat(:,i)) end do + else if (dim == 2) then + do i=1, size(mat)/size(mat, dim) + res(i) = mean_1_qp_qp(mat(i,:)) + end do end if end function mean_2_qp_qp diff --git a/src/stdlib_experimental_stat_mean.fypp.f90 b/src/stdlib_experimental_stat_mean.fypp.f90 index 1e059d3ed..98f96f482 100644 --- a/src/stdlib_experimental_stat_mean.fypp.f90 +++ b/src/stdlib_experimental_stat_mean.fypp.f90 @@ -22,28 +22,31 @@ end function mean_1_${k1}$_${k1}$ #:endfor #:for i1, k1, t1 in ikt -module function mean_2_${k1}$_${k1}$(mat, dim) result(res) +module function mean_2_all_${k1}$_${k1}$(mat) result(res) ${t1}$, intent(in) :: mat(:,:) - integer, intent(in), optional :: dim - ${t1}$, allocatable ::res(:) + ${t1}$ ::res - integer :: i - integer :: dim_ + res = sum(mat) / real(size(mat), ${k1}$) - dim_ = optval(dim, 1) +end function mean_2_all_${k1}$_${k1}$ +#:endfor - if (dim_ < 0 .or. dim_ > 2 ) call error_stop("ERROR (mean): invalid argument (dim) ") +#:for i1, k1, t1 in ikt +module function mean_2_${k1}$_${k1}$(mat, dim) result(res) + ${t1}$, intent(in) :: mat(:,:) + integer, intent(in) :: dim + ${t1}$ :: res(size(mat)/size(mat, dim)) - allocate(res(size(mat, dim_))) + integer :: i - if (dim_ == 1) then - do i=1, size(mat, dim_) - res(i) = mean_1_${k1}$_${k1}$(mat(i,:)) - end do - else if (dim_ == 2) then - do i=1, size(mat, dim_) + if (dim == 1) then + do i=1, size(mat)/size(mat, dim) res(i) = mean_1_${k1}$_${k1}$(mat(:,i)) end do + else if (dim == 2) then + do i=1, size(mat)/size(mat, dim) + res(i) = mean_1_${k1}$_${k1}$(mat(i,:)) + end do end if end function mean_2_${k1}$_${k1}$ diff --git a/src/tests/stat/test_mean.f90 b/src/tests/stat/test_mean.f90 index 1dcaba580..4a0e83e9b 100644 --- a/src/tests/stat/test_mean.f90 +++ b/src/tests/stat/test_mean.f90 @@ -11,14 +11,16 @@ program test_mean !sp call loadtxt("array1.dat", s) -call assert(sum( mean(s) - [1.5_sp, 3.5_sp, 5.5_sp, 7.5_sp] ) == 0.0_sp) -call assert(sum( mean(s, dim = 2) - [4.0_sp, 5.0_sp] ) == 0.0_sp) +call assert( mean(s) - 4.5_sp == 0.0_sp) +call assert(sum( mean(s, dim = 1) - [4.0_sp, 5.0_sp] ) == 0.0_sp) +call assert(sum( mean(s, dim = 2) - [1.5_dp, 3.5_dp, 5.5_dp, 7.5_dp] ) == 0.0_sp) !dp call loadtxt("array1.dat", d) -call assert(sum( mean(d) - [1.5_dp, 3.5_dp, 5.5_dp, 7.5_dp] ) == 0.0_dp) -call assert(sum( mean(d, dim = 2) - [4.0_dp, 5.0_dp] ) == 0.0_dp) +call assert(mean(d) - 4.5_dp == 0.0_dp) +call assert(sum( mean(d, dim = 1) - [4.0_dp, 5.0_dp] ) == 0.0_dp) +call assert(sum( mean(d, dim = 2) - [1.5_dp, 3.5_dp, 5.5_dp, 7.5_dp] ) == 0.0_dp) contains