Skip to content

Commit

Permalink
Add predicate-reduce specializations over sparse matrix columns (Juli…
Browse files Browse the repository at this point in the history
  • Loading branch information
jmert authored Oct 22, 2020
1 parent 31026e8 commit 3e7c8f3
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
34 changes: 34 additions & 0 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1960,6 +1960,40 @@ function _mapreducecols!(f, op::typeof(+), R::AbstractArray, A::AbstractSparseMa
R
end

# any(pred, A, dims = 1) => mapreduce(pred, |, A, dims = 1)
# final argument `post` is to allow post-mapping each columnar mapreduce
function _mapreducerows!(pred::P, ::typeof(|), R::AbstractMatrix{Bool}, A::AbstractSparseMatrixCSC{Tv},
post::F = identity) where {P, F, Tv}
nzval = nonzeros(A)
colptr = getcolptr(A)
m, n = size(A)
@inbounds for ii in 1:n
bi, ei = colptr[ii], colptr[ii+1]
len = ei - bi
# An empty column is trivial
if len == 0
R[1, ii] = post(pred(zero(Tv)))
continue
end
# If predicate on zero is true, then sparse column can be short-circuited
if pred(zero(Tv)) && len < m
R[1, ii] = post(true)
continue
end
# Otherwise reduce over the stored values
r = false
for jj in bi:(ei - 1)
r = pred(nzval[jj])
r && break
end
R[1, ii] = post(r)
end
return R
end
# all(pred, A, dims = 1) => mapreduce(pred, &, A, dims = 1) == .!mapreduce(!pred, |, A, dims = 1)
_mapreducerows!(pred::P, ::typeof(&), R::AbstractMatrix{Bool},
A::AbstractSparseMatrixCSC) where {P} = _mapreducerows!(!pred, |, R, A, !)

# findmax/min and argmax/min methods
# find first zero value in sparse matrix - return linear index in full matrix
# non-structural zeros are identified by x == 0 in line with the sparse constructors.
Expand Down
22 changes: 22 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2970,6 +2970,28 @@ end
@test SparseArrays.getcolptr(vA) == SparseArrays.getcolptr(A[:, 1:5])
end

@testset "any/all predicates over dims = 1" begin
As = sparse([2, 3], [2, 3], [0.0, 1.0]) # empty, structural zero, non-zero
Ad = Matrix(As)
Bs = copy(As) # like As, but full column
Bs[:,3] .= 1.0
Bd = Matrix(Bs)
Cs = copy(Bs) # like Bs, but full column is all structural zeros
Cs[:,3] .= 0.0
Cd = Matrix(Cs)

@testset "any($(repr(pred)))" for pred in (iszero, !iszero, >(-1.0), !=(1.0))
@test any(pred, As, dims = 1) == any(pred, Ad, dims = 1)
@test any(pred, Bs, dims = 1) == any(pred, Bd, dims = 1)
@test any(pred, Cs, dims = 1) == any(pred, Cd, dims = 1)
end
@testset "all($(repr(pred)))" for pred in (iszero, !iszero, >(-1.0), !=(1.0))
@test all(pred, As, dims = 1) == all(pred, Ad, dims = 1)
@test all(pred, Bs, dims = 1) == all(pred, Bd, dims = 1)
@test all(pred, Cs, dims = 1) == all(pred, Cd, dims = 1)
end
end

@testset "mapreducecols" begin
n = 20
m = 10
Expand Down

0 comments on commit 3e7c8f3

Please sign in to comment.