diff --git a/stdlib/SparseArrays/src/sparsematrix.jl b/stdlib/SparseArrays/src/sparsematrix.jl index 7bbf5a28828c2..ff05578d50e0f 100644 --- a/stdlib/SparseArrays/src/sparsematrix.jl +++ b/stdlib/SparseArrays/src/sparsematrix.jl @@ -1960,6 +1960,40 @@ function _mapreducecols!(f, op::typeof(+), R::AbstractArray, A::AbstractSparseMa R end +# any(pred, A, dims = 1) => mapreduce(pred, |, A, dims = 1) +# final argument `post` is to allow post-mapping each columnar mapreduce +function _mapreducerows!(pred::P, ::typeof(|), R::AbstractMatrix{Bool}, A::AbstractSparseMatrixCSC{Tv}, + post::F = identity) where {P, F, Tv} + nzval = nonzeros(A) + colptr = getcolptr(A) + m, n = size(A) + @inbounds for ii in 1:n + bi, ei = colptr[ii], colptr[ii+1] + len = ei - bi + # An empty column is trivial + if len == 0 + R[1, ii] = post(pred(zero(Tv))) + continue + end + # If predicate on zero is true, then sparse column can be short-circuited + if pred(zero(Tv)) && len < m + R[1, ii] = post(true) + continue + end + # Otherwise reduce over the stored values + r = false + for jj in bi:(ei - 1) + r = pred(nzval[jj]) + r && break + end + R[1, ii] = post(r) + end + return R +end +# all(pred, A, dims = 1) => mapreduce(pred, &, A, dims = 1) == .!mapreduce(!pred, |, A, dims = 1) +_mapreducerows!(pred::P, ::typeof(&), R::AbstractMatrix{Bool}, + A::AbstractSparseMatrixCSC) where {P} = _mapreducerows!(!pred, |, R, A, !) + # findmax/min and argmax/min methods # find first zero value in sparse matrix - return linear index in full matrix # non-structural zeros are identified by x == 0 in line with the sparse constructors. diff --git a/stdlib/SparseArrays/test/sparse.jl b/stdlib/SparseArrays/test/sparse.jl index aee7799f83264..535cad579c4cc 100644 --- a/stdlib/SparseArrays/test/sparse.jl +++ b/stdlib/SparseArrays/test/sparse.jl @@ -2970,6 +2970,28 @@ end @test SparseArrays.getcolptr(vA) == SparseArrays.getcolptr(A[:, 1:5]) end +@testset "any/all predicates over dims = 1" begin + As = sparse([2, 3], [2, 3], [0.0, 1.0]) # empty, structural zero, non-zero + Ad = Matrix(As) + Bs = copy(As) # like As, but full column + Bs[:,3] .= 1.0 + Bd = Matrix(Bs) + Cs = copy(Bs) # like Bs, but full column is all structural zeros + Cs[:,3] .= 0.0 + Cd = Matrix(Cs) + + @testset "any($(repr(pred)))" for pred in (iszero, !iszero, >(-1.0), !=(1.0)) + @test any(pred, As, dims = 1) == any(pred, Ad, dims = 1) + @test any(pred, Bs, dims = 1) == any(pred, Bd, dims = 1) + @test any(pred, Cs, dims = 1) == any(pred, Cd, dims = 1) + end + @testset "all($(repr(pred)))" for pred in (iszero, !iszero, >(-1.0), !=(1.0)) + @test all(pred, As, dims = 1) == all(pred, Ad, dims = 1) + @test all(pred, Bs, dims = 1) == all(pred, Bd, dims = 1) + @test all(pred, Cs, dims = 1) == all(pred, Cd, dims = 1) + end +end + @testset "mapreducecols" begin n = 20 m = 10