Skip to content

Commit

Permalink
add dot(x, A, y) for sparse matrices (#1263)
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-roehrich authored Oct 28, 2023
1 parent ae9e656 commit 6dcd2c8
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/src/sparse/intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ Various products:
*(::SRow{T}, ::SMat{T}) where {T}
```

```@docs
dot(::SRow{T}, ::SMat{T}, ::SRow{T}) where T
dot(::MatrixElem{T}, ::SMat{T}, ::MatrixElem{T}) where T
dot(::AbstractVector{T}, ::SMat{T}, ::AbstractVector{T}) where T
```

Other:
```@docs
sparse(::SMat)
Expand Down
85 changes: 85 additions & 0 deletions src/Sparse/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,91 @@ function *(A::SMat, b)
return A * base_ring(A)(b)
end

################################################################################
#
# Dot product
#
################################################################################

@doc raw"""
dot(x::SRow{T}, A::SMat{T}, y::SRow{T}) where T -> T
Return the generalized dot product `dot(x, A*y)`.
"""
function dot(x::SRow{T}, A::SMat{T}, y::SRow{T}) where T
v = zero(T)
px = 1
for i in 1:length(A.rows)
while px <= length(x.pos) && x.pos[px] < i
px += 1
end
if px > length(x.pos)
break
elseif x.pos[px] > i
continue
end

s = zero(T)
py = 1
for j in 1:length(A[i].pos)
while py <= length(y.pos) && y.pos[py] < A[i].pos[j]
py += 1
end
if py > length(y.pos)
break
elseif y.pos[py] > A[i].pos[j]
continue
end

s += A[i].values[j] * y.values[py]
end

v += x.values[px] * s
end

return v
end

@doc raw"""
dot(x::AbstractVector{T}, A::SMat{T}, y::AbstractVector{T}) where T -> T
Return the generalized dot product `dot(x, A*y)`.
"""
function dot(x::AbstractVector{T}, A::SMat{T}, y::AbstractVector{T}) where T
@req length(x) == nrows(A) && ncols(A) <= length(y) "incompatible matrix dimensions"

v = zero(T)
for i in 1:length(A.rows)
s = T(0)
for j in 1:length(A[i].pos)
s += A[i].values[j] * y[A[i].pos[j]]
end
v += x[i] * s
end

return v
end

@doc raw"""
dot(x::MatrixElem{T}, A::SMat{T}, y::MatrixElem{T}) where T -> T
Return the generalized dot product `dot(x, A*y)`.
"""
function dot(x::MatrixElem{T}, A::SMat{T}, y::MatrixElem{T}) where T
@req length(x) == nrows(A) && ncols(A) <= length(y) "incompatible matrix dimensions"

v = zero(T)
for i in 1:length(A.rows)
s = zero(T)
for j in 1:length(A[i].pos)
s += A[i].values[j] * y[A[i].pos[j]]
end
v += x[i] * s
end

return v
end

################################################################################
#
# Submatrix
Expand Down
33 changes: 33 additions & 0 deletions test/Sparse/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,39 @@ using Hecke.SparseArrays
E = @inferred D * R(3)
@test E == sparse_matrix(R, [3 0 0; 0 0 3; 0 0 0])

# Dot product

D = sparse_matrix(ZZ, [4 -2; -2 2])
E = sparse_matrix(ZZ, [3 0 0; 0 3 0])

@test dot(sparse_row(ZZ, [1], [1]), D, sparse_row(ZZ, [1], [1])) == 4
@test dot(sparse_row(ZZ, [1, 2], [1, 1]), D, sparse_row(ZZ, [1, 2], [1, 2])) == 2
@test dot(sparse_row(ZZ, [1], [1]), D, sparse_row(ZZ, [2], [1])) == -2

@test dot(sparse_row(ZZ, [1, 4], [1, 2]), D, sparse_row(ZZ, [2], [1])) == -2
@test dot(sparse_row(ZZ, [1, 4], [1, 2]), E, sparse_row(ZZ, [2], [1])) == 0
@test dot(sparse_row(ZZ, [1, 2], [1, 2]), E, sparse_row(ZZ, [2], [1])) == 6

@test dot(ZZRingElem[1, 0], D, ZZRingElem[1, 0]) == 4
@test dot(ZZRingElem[1, 1], D, ZZRingElem[1, 2]) == 2
@test dot(ZZRingElem[1, 0], D, ZZRingElem[0, 1]) == -2
@test dot(ZZRingElem[1, 0], E, ZZRingElem[0, 1, 2]) == 0
@test dot(ZZRingElem[0, 1], E, ZZRingElem[0, 1, 2]) == 3

@test dot(ZZ[1 0], D, ZZ[1 0]) == 4
@test dot(ZZ[1; 1], D, ZZ[1 2]) == 2
@test dot(ZZ[1 0], D, ZZ[0; 1]) == -2
@test dot(ZZ[1 0], E, ZZ[0 1 2]) == 0
@test dot(ZZ[0 1], E, ZZ[0 1 2]) == 3

@test_throws ArgumentError dot(ZZRingElem[1], D, ZZRingElem[0, 1])
@test_throws ArgumentError dot(ZZRingElem[1, 0], D, ZZRingElem[0])
@test_throws ArgumentError dot(ZZRingElem[1, 0, 2], E, ZZRingElem[0, 1])

@test_throws ArgumentError dot(ZZ[1 0 0], D, ZZ[1 0])
@test_throws ArgumentError dot(ZZ[0 1 2], E, ZZ[0 1])
@test_throws ArgumentError dot(ZZ[1 0; 0 0], D, ZZ[1 0 0 0])

# Submatrix

D = sparse_matrix(FlintZZ, [1 5 3; 0 0 0; 0 1 0])
Expand Down

0 comments on commit 6dcd2c8

Please sign in to comment.