Skip to content

Commit

Permalink
check JET errors in view
Browse files Browse the repository at this point in the history
  • Loading branch information
guimarqu committed Aug 2, 2023
1 parent cf74929 commit 3886bb7
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 23 deletions.
28 changes: 23 additions & 5 deletions src/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,17 @@ function Base.setindex!(m::DynamicSparseMatrix{K,L,T}, val, row::K, col::L) wher
m.n = max(m.n, col)
end
if m.fillmode
addelem!(m.buffer, row, col, val)
buffer = m.buffer
@assert !isnothing(buffer)
addelem!(buffer, row, col, val)
else
m.colmajor[row, col] = val
m.rowmajor[col, row] = val
colmajor = m.colmajor
@assert !isnothing(colmajor)
colmajor[row, col] = val

rowmajor = m.rowmajor
@assert !isnothing(rowmajor)
rowmajor[col, row] = val
end
return m
end
Expand All @@ -61,12 +68,23 @@ function Base.getindex(m::DynamicSparseMatrix, row, col)
end

function Base.view(m::DynamicSparseMatrix{K,L,T}, row::K, ::Colon) where {K,L,T}
return m.fillmode ? view(m.buffer, row, :) : view(m.rowmajor, :, row)
if m.fillmode
# Do not allow to create a view on a buffer, otherwise the method is type unstable.
error("Matrix is in fill mode, cannot create a view. However, you can use the view method on the buffer.")
# buffer = m.buffer
# @assert !isnothing(buffer)
# return view(buffer, row, :)
end
rowmajor = m.rowmajor
@assert !isnothing(rowmajor)
return view(rowmajor, :, row)
end

function Base.view(m::DynamicSparseMatrix{K,L,T}, ::Colon, col::L) where {K,L,T}
m.fillmode && error("View of a column not available in fill mode.")
return view(m.colmajor, :, col)
colmajor = m.colmajor
@assert !isnothing(colmajor)
return view(colmajor, :, col)
end

Base.ndims(m::DynamicSparseMatrix) = 2
Expand Down
35 changes: 24 additions & 11 deletions src/pcsr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,30 @@ function addpartition!(pcsc::PackedCSC{K,T}) where {K,T}
end

function addpartition!(pcsc::PackedCSC{K,T}, prev_sem_id::Int) where {K,T}
semaphores = pcsc.semaphores
@assert !isnothing(semaphores)
sem_key = semaphore_key(K)
nb_semaphores = length(pcsc.semaphores)
nb_semaphores = length(semaphores)
sem_pos = 0
if pcsc.semaphores[prev_sem_id + 1] === nothing
next_sem_id = _nextnonemptypos(pcsc.semaphores, prev_sem_id + 1)
sem_pos = pcsc.semaphores[next_sem_id] - 1 # insert the new semaphore in the pma.array just before the next one
semaphore_target = semaphores[prev_sem_id + 1]
if semaphore_target === nothing
next_sem_id = _nextnonemptypos(semaphores, prev_sem_id + 1)
next_semaphore = semaphores[next_sem_id]
@assert !isnothing(next_semaphore)
sem_pos = next_semaphore - 1 # insert the new semaphore in the pma.array just before the next one
else
sem_pos = pcsc.semaphores[prev_sem_id + 1] - 1 # insert the new semaphore just before the next one
resize!(pcsc.semaphores, nb_semaphores + 1) # create room for the position of the new semaphore
sem_pos = semaphore_target - 1 # insert the new semaphore just before the next one
resize!(semaphores, nb_semaphores + 1) # create room for the position of the new semaphore
for i in nb_semaphores:-1:(prev_sem_id+1)
moved_sem_pos = pcsc.semaphores[i]
pcsc.semaphores[i+1] = pcsc.semaphores[i]
semaphores[i+1] = semaphores[i]
pcsc.pma.array[moved_sem_pos] = (sem_key, T(i+1))
end
end
pcsc.nb_partitions += 1
sem_val = T(prev_sem_id+1)
insert_pos, new_elem = _insert!(pcsc.pma.array, sem_key, sem_val, sem_pos, pcsc.semaphores)
pcsc.semaphores[prev_sem_id+1] = insert_pos
semaphores[prev_sem_id+1] = insert_pos
if new_elem
pcsc.pma.nb_elements += 1
win_start, win_end, nbcells = _look_for_rebalance!(pcsc.pma, insert_pos)
Expand Down Expand Up @@ -162,12 +167,19 @@ function addcolumn!(mpcsc::MappedPackedCSC{K,L,T}, col::L, prev_col_pos::Int) wh
return col_pos
end

_pos_of_partition_start(pcsc, partition) = pcsc.semaphores[partition]
function _pos_of_partition_start(pcsc, partition)
partition_start_pos = pcsc.semaphores[partition]
@assert !isnothing(partition_start_pos)
return partition_start_pos
end

function _pos_of_partition_end(pcsc, partition)
pos = length(pcsc.pma.array)
next_partition = _nextnonemptypos(pcsc.semaphores, partition)
if next_partition != 0
pos = pcsc.semaphores[next_partition] - 1
next_partition_start_pos = pcsc.semaphores[next_partition]
@assert !isnothing(next_partition_start_pos)
pos = next_partition_start_pos - 1
end
return pos
end
Expand Down Expand Up @@ -343,7 +355,8 @@ function _dynamicsparse(
) where {K,L,T}
!always_use_map && error("TODO issue #2.")

p = sortperm(collect(zip(J,I)), alg=QuickSort) # Columns first
ind = collect(zip(J,I))
p = sortperm(ind, alg=QuickSort) # Columns first
@inbounds I = I[p]
@inbounds J = J[p]
@inbounds V = V[p]
Expand Down
4 changes: 3 additions & 1 deletion src/views.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ function Base.view(::Buffer{L,K,T}, ::Colon, col::L) where {K,L,T}
throw(ArgumentError("Cannot view a column of the BufferView."))
end

function Base.iterate(bf::BufferView, state = 1)
Base.eltype(::Type{BufferView{L,K,T}}) where {K,L,T} = Tuple{L,T}
Base.length(bf::BufferView) = length(bf.colids)
function Base.iterate(bf::BufferView, state::Int = 1)
state > length(bf.vals) && return nothing
return ((bf.colids[state], bf.vals[state]), state + 1)
end
2 changes: 2 additions & 0 deletions test/unit/unitests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ function unit_tests()

@testset "Views" begin
test_views()
test_buffer_views()
@test_call test_buffer_views()
end

@testset "Sparse Matrix Vector Multiplication" begin
Expand Down
43 changes: 37 additions & 6 deletions test/unit/views.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
function test_views()
I = [1, 1, 2, 4, 3, 5, 1, 4, 1, 5, 1, 5, 4, 4, 3, 9, 1]
J = [4, 3, 3, 7, 18, 9, 3, 18, 4, 2, 3, 1, 7, 3, 3, 3, 18]
V = [1, 8, 10, 2, -5, 3, 2, 1, 1, 1, 5, 3, 2, 1, 7, 8, 1]
matrix = dynamicsparse(I,J,V)

function _test_views(matrix)
# get the row with id 5
ids = Int[]
vals = Int[]
Expand Down Expand Up @@ -33,4 +28,40 @@ function test_views()
end
@test ids == [1,3,4]
@test vals == [1,-5,1]
return
end

function test_views()
I = [1, 1, 2, 4, 3, 5, 1, 4, 1, 5, 1, 5, 4, 4, 3, 9, 1]
J = [4, 3, 3, 7, 18, 9, 3, 18, 4, 2, 3, 1, 7, 3, 3, 3, 18]
V = [1, 8, 10, 2, -5, 3, 2, 1, 1, 1, 5, 3, 2, 1, 7, 8, 1]
matrix = dynamicsparse(I,J,V)

_test_views(matrix)
@test_call _test_views(matrix)
end

function test_buffer_views()
matrix = dynamicsparse(Int, Int, Int)

matrix[1,2] = 1
matrix[2,1] = 2
matrix[2,2] = 3
matrix[3,1] = 4
matrix[3,2] = 5
matrix[1,7] = 3

buffer = matrix.buffer
@assert !isnothing(buffer)

ids = Int[]
vals = Int[]
for (id, val) in view(buffer, 1, :)
push!(ids, id)
push!(vals, val)
end

@test ids == [2, 7]
@test vals == [1, 3]
return
end

0 comments on commit 3886bb7

Please sign in to comment.