Skip to content

Commit

Permalink
add init keyword argument to count() (#37461)
Browse files Browse the repository at this point in the history
  • Loading branch information
simeonschaub authored Oct 29, 2020
1 parent d5ad85a commit 506fbdf
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 21 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ New library features
* The postfix operator `'ᵀ` can now be used as an alias for `transpose` ([#38043]).
* `keys(io::IO)` has been added, which returns all keys of `io` if `io` is an `IOContext` and an empty
`Base.KeySet` otherwise ([#37753]).
* `count` now accepts an optional `init` argument to control the accumulation type ([#37461]).

Standard library changes
------------------------
Expand Down
8 changes: 4 additions & 4 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1386,15 +1386,15 @@ circshift!(B::BitVector, i::Integer) = circshift!(B, B, i)

## count & find ##

function bitcount(Bc::Vector{UInt64})
n = 0
function bitcount(Bc::Vector{UInt64}; init::T=0) where {T}
n::T = init
@inbounds for i = 1:length(Bc)
n += count_ones(Bc[i])
n = (n + count_ones(Bc[i])) % T
end
return n
end

count(B::BitArray) = bitcount(B.chunks)
count(B::BitArray; init=0) = bitcount(B.chunks; init)

function unsafe_bitfindnext(Bc::Vector{UInt64}, start::Int)
chunk_start = _div64(start-1)+1
Expand Down
32 changes: 19 additions & 13 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -938,12 +938,15 @@ end
_bool(f) = x->f(x)::Bool

"""
count(p, itr) -> Integer
count(itr) -> Integer
count([f=identity,] itr; init=0) -> Integer
Count the number of elements in `itr` for which predicate `p` returns `true`.
If `p` is omitted, counts the number of `true` elements in `itr` (which
should be a collection of boolean values).
Count the number of elements in `itr` for which the function `f` returns `true`.
If `f` is omitted, count the number of `true` elements in `itr` (which
should be a collection of boolean values). `init` optionally specifies the value
to start counting from and therefore also determines the output type.
!!! compat "Julia 1.6"
`init` keyword was added in Julia 1.6.
# Examples
```jldoctest
Expand All @@ -952,32 +955,35 @@ julia> count(i->(4<=i<=6), [2,3,4,5,6])
julia> count([true, false, true, true])
3
julia> count(>(3), 1:7, init=0x03)
0x07
```
"""
count(itr) = count(identity, itr)
count(itr; init=0) = count(identity, itr; init)

count(f, itr) = _simple_count(f, itr)
count(f, itr; init=0) = _simple_count(f, itr, init)

function _simple_count(pred, itr)
n = 0
function _simple_count(pred, itr, init::T) where {T}
n::T = init
for x in itr
n += pred(x)::Bool
end
return n
end

function count(::typeof(identity), x::Array{Bool})
n = 0
function _simple_count(::typeof(identity), x::Array{Bool}, init::T=0) where {T}
n::T = init
chunks = length(x) ÷ sizeof(UInt)
mask = 0x0101010101010101 % UInt
GC.@preserve x begin
ptr = Ptr{UInt}(pointer(x))
for i in 1:chunks
n += count_ones(unsafe_load(ptr, i) & mask)
n = (n + count_ones(unsafe_load(ptr, i) & mask)) % T
end
end
for i in sizeof(UInt)*chunks+1:length(x)
n += x[i]
n = (n + x[i]) % T
end
return n
end
11 changes: 7 additions & 4 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,9 @@ dimensions.
!!! compat "Julia 1.5"
`dims` keyword was added in Julia 1.5.
!!! compat "Julia 1.6"
`init` keyword was added in Julia 1.6.
# Examples
```jldoctest
julia> A = [1 2; 3 4]
Expand All @@ -386,11 +389,11 @@ julia> count(<=(2), A, dims=2)
0
```
"""
count(A::AbstractArrayOrBroadcasted; dims=:) = count(identity, A, dims=dims)
count(f, A::AbstractArrayOrBroadcasted; dims=:) = _count(f, A, dims)
count(A::AbstractArrayOrBroadcasted; dims=:, init=0) = count(identity, A; dims, init)
count(f, A::AbstractArrayOrBroadcasted; dims=:, init=0) = _count(f, A, dims, init)

_count(f, A::AbstractArrayOrBroadcasted, dims::Colon) = _simple_count(f, A)
_count(f, A::AbstractArrayOrBroadcasted, dims) = mapreduce(_bool(f), add_sum, A, dims=dims, init=0)
_count(f, A::AbstractArrayOrBroadcasted, dims::Colon, init) = _simple_count(f, A, init)
_count(f, A::AbstractArrayOrBroadcasted, dims, init) = mapreduce(_bool(f), add_sum, A; dims, init)

"""
count!([f=identity,] r, A)
Expand Down
2 changes: 2 additions & 0 deletions test/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,8 @@ timesofar("datamove")
@check_bit_operation findall(falses(t)) ret_type
@check_bit_operation findall(bitrand(t)) ret_type
end

@test count(trues(2, 2), init=0x03) === 0x07
end

timesofar("find")
Expand Down
6 changes: 6 additions & 0 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,12 @@ struct NonFunctionIsZero end
@test count(NonFunctionIsZero(), [0]) == 1
@test count(NonFunctionIsZero(), [1]) == 0

@test count(Iterators.repeated(true, 3), init=0x04) === 0x07
@test count(!=(2), Iterators.take(1:7, 3), init=Int32(0)) === Int32(2)
@test count(identity, [true, false], init=Int8(5)) === Int8(6)
@test count(!, [true false; false true], dims=:, init=Int16(0)) === Int16(2)
@test isequal(count(identity, [true false; false true], dims=2, init=UInt(4)), reshape(UInt[5, 5], 2, 1))

## cumsum, cummin, cummax

z = rand(10^6)
Expand Down
12 changes: 12 additions & 0 deletions test/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re
@test @inferred(maximum(abs, Areduc, dims=region)) safe_maxabs(Areduc, region)
@test @inferred(minimum(abs, Areduc, dims=region)) safe_minabs(Areduc, region)
@test @inferred(count(!, Breduc, dims=region)) safe_count(.!Breduc, region)

@test isequal(
@inferred(count(Breduc, dims=region, init=0x02)),
safe_count(Breduc, region) .% UInt8 .+ 0x02,
)
@test isequal(
@inferred(count(!, Breduc, dims=region, init=Int16(0))),
safe_count(.!Breduc, region) .% Int16,
)
end

# Combining dims and init
Expand Down Expand Up @@ -446,3 +455,6 @@ end
@test_throws TypeError count([1], dims=1)
@test_throws TypeError count!([1], [1])
end

@test @inferred(count(false:true, dims=:, init=0x0004)) === 0x0005
@test @inferred(count(isodd, reshape(1:9, 3, 3), dims=:, init=Int128(0))) === Int128(5)

0 comments on commit 506fbdf

Please sign in to comment.