Skip to content

Commit

Permalink
Check if Ti is sufficient for size of SparseMatrixCSC (JuliaLang#31118)
Browse files Browse the repository at this point in the history
* sparse additional checks to avoid narrow Ti
* Fixes JuliaLang#31024.
  • Loading branch information
KlausC authored and ViralBShah committed Feb 20, 2019
1 parent f37005a commit 4236774
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
18 changes: 14 additions & 4 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ struct SparseMatrixCSC{Tv,Ti<:Integer} <: AbstractSparseMatrix{Tv,Ti}

function SparseMatrixCSC{Tv,Ti}(m::Integer, n::Integer, colptr::Vector{Ti}, rowval::Vector{Ti},
nzval::Vector{Tv}) where {Tv,Ti<:Integer}
@noinline throwsz(str, lbl, k) =
throw(ArgumentError("number of $str ($lbl) must be ≥ 0, got $k"))
m < 0 && throwsz("rows", 'm', m)
n < 0 && throwsz("columns", 'n', n)

sparse_check_Ti(m, n, Ti)
new(Int(m), Int(n), colptr, rowval, nzval)
end
end
Expand All @@ -35,6 +33,17 @@ function SparseMatrixCSC(m::Integer, n::Integer, colptr::Vector, rowval::Vector,
SparseMatrixCSC{Tv,Ti}(m, n, colptr, rowval, nzval)
end

function sparse_check_Ti(m::Integer, n::Integer, Ti::Type)
@noinline throwsz(str, lbl, k) =
throw(ArgumentError("number of $str ($lbl) must be ≥ 0, got $k"))
@noinline throwTi(str, lbl, k) =
throw(ArgumentError("$str ($lbl = $k) does not fit in Ti = $(Ti)"))
m < 0 && throwsz("rows", 'm', m)
n < 0 && throwsz("columns", 'n', n)
!isbitstype(Ti) || m typemax(Ti) || throwTi("number of rows", "m", m)
!isbitstype(Ti) || n typemax(Ti) || throwTi("number of columns", "n", n)
!isbitstype(Ti) || n*m+1 typemax(Ti) || throwTi("maximal nnz+1", "m*n+1", n*m+1)
end
size(S::SparseMatrixCSC) = (S.m, S.n)

# Define an alias for views of a SparseMatrixCSC which include all rows and a unit range of the columns.
Expand Down Expand Up @@ -590,6 +599,7 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti},
csccolptr::Vector{Ti}, cscrowval::Vector{Ti}, cscnzval::Vector{Tv}) where {Tv,Ti<:Integer}

require_one_based_indexing(I, J, V)
sparse_check_Ti(m, n, Ti)
# Compute the CSR form's row counts and store them shifted forward by one in csrrowptr
fill!(csrrowptr, Ti(0))
coolen = length(I)
Expand Down
8 changes: 8 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2510,4 +2510,12 @@ end
end
end

@testset "Ti cannot store all potential values #31024" begin
@test_throws ArgumentError SparseMatrixCSC(128, 1, [Int8(1), Int8(1)], Int8[], Int[])
@test_throws ArgumentError SparseMatrixCSC(12, 12, [Int8(1), Int8(1)], Int8[], Int[])
I1 = [Int8(i) for i in 1:20 for _ in 1:20]
J1 = [Int8(i) for _ in 1:20 for i in 1:20]
@test_throws ArgumentError sparse(I1, J1, zero(length(I1)zero(length(I1))))
end

end # module

0 comments on commit 4236774

Please sign in to comment.