From 6dcd2c895d8af1fdcaee59ed98bdc8517253d8f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20R=C3=B6hrich?= <47457568+felix-roehrich@users.noreply.github.com> Date: Sat, 28 Oct 2023 16:35:36 +0200 Subject: [PATCH] add `dot(x, A, y)` for sparse matrices (#1263) --- docs/src/sparse/intro.md | 6 +++ src/Sparse/Matrix.jl | 85 ++++++++++++++++++++++++++++++++++++++++ test/Sparse/Matrix.jl | 33 ++++++++++++++++ 3 files changed, 124 insertions(+) diff --git a/docs/src/sparse/intro.md b/docs/src/sparse/intro.md index 17154b9905..2a7cce961f 100644 --- a/docs/src/sparse/intro.md +++ b/docs/src/sparse/intro.md @@ -211,6 +211,12 @@ Various products: *(::SRow{T}, ::SMat{T}) where {T} ``` +```@docs +dot(::SRow{T}, ::SMat{T}, ::SRow{T}) where T +dot(::MatrixElem{T}, ::SMat{T}, ::MatrixElem{T}) where T +dot(::AbstractVector{T}, ::SMat{T}, ::AbstractVector{T}) where T +``` + Other: ```@docs sparse(::SMat) diff --git a/src/Sparse/Matrix.jl b/src/Sparse/Matrix.jl index 5c12dc2647..c7a4000480 100644 --- a/src/Sparse/Matrix.jl +++ b/src/Sparse/Matrix.jl @@ -723,6 +723,91 @@ function *(A::SMat, b) return A * base_ring(A)(b) end +################################################################################ +# +# Dot product +# +################################################################################ + +@doc raw""" + dot(x::SRow{T}, A::SMat{T}, y::SRow{T}) where T -> T + +Return the generalized dot product `dot(x, A*y)`. +""" +function dot(x::SRow{T}, A::SMat{T}, y::SRow{T}) where T + v = zero(T) + px = 1 + for i in 1:length(A.rows) + while px <= length(x.pos) && x.pos[px] < i + px += 1 + end + if px > length(x.pos) + break + elseif x.pos[px] > i + continue + end + + s = zero(T) + py = 1 + for j in 1:length(A[i].pos) + while py <= length(y.pos) && y.pos[py] < A[i].pos[j] + py += 1 + end + if py > length(y.pos) + break + elseif y.pos[py] > A[i].pos[j] + continue + end + + s += A[i].values[j] * y.values[py] + end + + v += x.values[px] * s + end + + return v +end + +@doc raw""" + dot(x::AbstractVector{T}, A::SMat{T}, y::AbstractVector{T}) where T -> T + +Return the generalized dot product `dot(x, A*y)`. +""" +function dot(x::AbstractVector{T}, A::SMat{T}, y::AbstractVector{T}) where T + @req length(x) == nrows(A) && ncols(A) <= length(y) "incompatible matrix dimensions" + + v = zero(T) + for i in 1:length(A.rows) + s = T(0) + for j in 1:length(A[i].pos) + s += A[i].values[j] * y[A[i].pos[j]] + end + v += x[i] * s + end + + return v +end + +@doc raw""" + dot(x::MatrixElem{T}, A::SMat{T}, y::MatrixElem{T}) where T -> T + +Return the generalized dot product `dot(x, A*y)`. +""" +function dot(x::MatrixElem{T}, A::SMat{T}, y::MatrixElem{T}) where T + @req length(x) == nrows(A) && ncols(A) <= length(y) "incompatible matrix dimensions" + + v = zero(T) + for i in 1:length(A.rows) + s = zero(T) + for j in 1:length(A[i].pos) + s += A[i].values[j] * y[A[i].pos[j]] + end + v += x[i] * s + end + + return v +end + ################################################################################ # # Submatrix diff --git a/test/Sparse/Matrix.jl b/test/Sparse/Matrix.jl index 89f00121c3..4ba86ad4a4 100644 --- a/test/Sparse/Matrix.jl +++ b/test/Sparse/Matrix.jl @@ -207,6 +207,39 @@ using Hecke.SparseArrays E = @inferred D * R(3) @test E == sparse_matrix(R, [3 0 0; 0 0 3; 0 0 0]) + # Dot product + + D = sparse_matrix(ZZ, [4 -2; -2 2]) + E = sparse_matrix(ZZ, [3 0 0; 0 3 0]) + + @test dot(sparse_row(ZZ, [1], [1]), D, sparse_row(ZZ, [1], [1])) == 4 + @test dot(sparse_row(ZZ, [1, 2], [1, 1]), D, sparse_row(ZZ, [1, 2], [1, 2])) == 2 + @test dot(sparse_row(ZZ, [1], [1]), D, sparse_row(ZZ, [2], [1])) == -2 + + @test dot(sparse_row(ZZ, [1, 4], [1, 2]), D, sparse_row(ZZ, [2], [1])) == -2 + @test dot(sparse_row(ZZ, [1, 4], [1, 2]), E, sparse_row(ZZ, [2], [1])) == 0 + @test dot(sparse_row(ZZ, [1, 2], [1, 2]), E, sparse_row(ZZ, [2], [1])) == 6 + + @test dot(ZZRingElem[1, 0], D, ZZRingElem[1, 0]) == 4 + @test dot(ZZRingElem[1, 1], D, ZZRingElem[1, 2]) == 2 + @test dot(ZZRingElem[1, 0], D, ZZRingElem[0, 1]) == -2 + @test dot(ZZRingElem[1, 0], E, ZZRingElem[0, 1, 2]) == 0 + @test dot(ZZRingElem[0, 1], E, ZZRingElem[0, 1, 2]) == 3 + + @test dot(ZZ[1 0], D, ZZ[1 0]) == 4 + @test dot(ZZ[1; 1], D, ZZ[1 2]) == 2 + @test dot(ZZ[1 0], D, ZZ[0; 1]) == -2 + @test dot(ZZ[1 0], E, ZZ[0 1 2]) == 0 + @test dot(ZZ[0 1], E, ZZ[0 1 2]) == 3 + + @test_throws ArgumentError dot(ZZRingElem[1], D, ZZRingElem[0, 1]) + @test_throws ArgumentError dot(ZZRingElem[1, 0], D, ZZRingElem[0]) + @test_throws ArgumentError dot(ZZRingElem[1, 0, 2], E, ZZRingElem[0, 1]) + + @test_throws ArgumentError dot(ZZ[1 0 0], D, ZZ[1 0]) + @test_throws ArgumentError dot(ZZ[0 1 2], E, ZZ[0 1]) + @test_throws ArgumentError dot(ZZ[1 0; 0 0], D, ZZ[1 0 0 0]) + # Submatrix D = sparse_matrix(FlintZZ, [1 5 3; 0 0 0; 0 1 0])