diff --git a/src/Statistics.jl b/src/Statistics.jl index 07151e47..4843feb1 100644 --- a/src/Statistics.jl +++ b/src/Statistics.jl @@ -174,46 +174,52 @@ realXcY(x::Complex, y::Complex) = real(x)*real(y) + imag(x)*imag(y) var(iterable; corrected::Bool=true, mean=nothing) = _var(iterable, corrected, mean) -function _var(iterable, corrected::Bool, mean) +function _var(iterable, corrected::Bool, ::Nothing) y = iterate(iterable) if y === nothing T = eltype(iterable) return oftype((abs2(zero(T)) + abs2(zero(T)))/2, NaN) end + # Use Welford algorithm as seen in (among other places) + # Knuth's TAOCP, Vol 2, page 232, 3rd edition. count = 1 value, state = y y = iterate(iterable, state) - if mean === nothing - # Use Welford algorithm as seen in (among other places) - # Knuth's TAOCP, Vol 2, page 232, 3rd edition. - M = value / 1 - S = real(zero(M)) - while y !== nothing - value, state = y - y = iterate(iterable, state) - count += 1 - new_M = M + (value - M) / count - S = S + realXcY(value - M, value - new_M) - M = new_M - end - return S / (count - Int(corrected)) - elseif isa(mean, Number) # mean provided - # Cannot use a compensated version, e.g. the one from - # "Updating Formulae and a Pairwise Algorithm for Computing Sample Variances." - # by Chan, Golub, and LeVeque, Technical Report STAN-CS-79-773, - # Department of Computer Science, Stanford University, - # because user can provide mean value that is different to mean(iterable) - sum2 = abs2(value - mean::Number) - while y !== nothing - value, state = y - y = iterate(iterable, state) - count += 1 - sum2 += abs2(value - mean) - end - return sum2 / (count - Int(corrected)) - else - throw(ArgumentError("invalid value of mean, $(mean)::$(typeof(mean))")) + M = value / 1 + S = real(zero(M)) + while y !== nothing + value, state = y + y = iterate(iterable, state) + count += 1 + new_M = M + (value - M) / count + S = S + realXcY(value - M, value - new_M) + M = new_M + end + return S / (count - corrected) +end + +function _var(iterable, corrected::Bool, mean::Number) + y = iterate(iterable) + if y === nothing + T = eltype(iterable) + return oftype((abs2(zero(T)) + abs2(zero(T)))/2, NaN) + end + # Cannot use a compensated version, e.g. the one from + # "Updating Formulae and a Pairwise Algorithm for Computing Sample Variances." + # by Chan, Golub, and LeVeque, Technical Report STAN-CS-79-773, + # Department of Computer Science, Stanford University, + # because user can provide mean value that is different to mean(iterable) + count = 1 + value, state = y + y = iterate(iterable, state) + sum2 = abs2(value - mean) + while y !== nothing + value, state = y + y = iterate(iterable, state) + count += 1 + sum2 += abs2(value - mean) end + return sum2 / (count - corrected) end centralizedabs2fun(m) = x -> abs2.(x - m) @@ -296,11 +302,13 @@ over dimensions, and `m` may contain means for each dimension of `itr`. """ varm(A::AbstractArray, m::AbstractArray; corrected::Bool=true, dims=:) = _varm(A, m, corrected, dims) +varm(A::AbstractArray, m; corrected::Bool=true) = _varm(A, m, corrected, :) + +varm(iterable, m; corrected::Bool=true) = _var(iterable, corrected, m) + _varm(A::AbstractArray{T}, m, corrected::Bool, region) where {T} = varm!(Base.reducedim_init(t -> abs2(t)/2, +, A, region), A, m; corrected=corrected) -varm(A::AbstractArray, m; corrected::Bool=true) = _varm(A, m, corrected, :) - function _varm(A::AbstractArray{T}, m, corrected::Bool, ::Colon) where T n = length(A) n == 0 && return oftype((abs2(zero(T)) + abs2(zero(T)))/2, NaN) @@ -334,13 +342,15 @@ over dimensions, and `mean` may contain means for each dimension of `itr`. """ var(A::AbstractArray; corrected::Bool=true, mean=nothing, dims=:) = _var(A, corrected, mean, dims) -_var(A::AbstractArray, corrected::Bool, mean, dims) = - varm(A, something(mean, Statistics.mean(A, dims=dims)); corrected=corrected, dims=dims) +_var(A::AbstractArray, corrected::Bool, mean, dims) = _varm(A, mean, corrected, dims) -_var(A::AbstractArray, corrected::Bool, mean, ::Colon) = - real(varm(A, something(mean, Statistics.mean(A)); corrected=corrected)) +_var(A::AbstractArray, corrected::Bool, mean, ::Colon) = real(_var(A, corrected, mean)) -varm(iterable, m; corrected::Bool=true) = _var(iterable, corrected, m) +_var(A::AbstractArray, corrected::Bool, ::Nothing, dims) = + _varm(A, Statistics.mean(A, dims=dims), corrected, dims) + +_var(A::AbstractArray, corrected::Bool, ::Nothing, ::Colon) = + real(_var(A, corrected, Statistics.mean(A))) ## variances over ranges