From d846989b1e3c04a80e1ae48a41d0e2a8c6ecc406 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20G=C3=B6ttgens?= Date: Wed, 27 Sep 2023 16:27:21 +0200 Subject: [PATCH 1/2] Add test --- test/Sparse/Matrix.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/Sparse/Matrix.jl b/test/Sparse/Matrix.jl index a21589269c..401eb34c92 100644 --- a/test/Sparse/Matrix.jl +++ b/test/Sparse/Matrix.jl @@ -315,6 +315,12 @@ using Hecke.SparseArrays E = SparseArrays.sparse(D) @test Matrix(E) == ZZRingElem[1 5 3; 0 -10 0; 0 1 0] @test Array(E) == ZZRingElem[1 5 3; 0 -10 0; 0 1 0] + + # kronecker_product + D1 = sparse_matrix(FlintZZ, [12 403 -23; 0 0 122; -1 2 99]) + D2 = sparse_matrix(FlintZZ, [81 0 2; 31 0 -5]) + E = @inferred kronecker_product(D1, D2) + @test E == sparse_matrix(kronecker_product(matrix(D1), matrix(D2))) end @testset "Oscar #2128" begin @@ -328,3 +334,8 @@ end S1 = sparse_matrix(ZZ,[-1 0; 0 -1]) @test S0 + S1 == sparse_matrix(ZZ, 2, 2) end + +@testset "Hecke #1227" begin + A = sparse_matrix(FlintZZ, [2 0; 0 0]) + @test kronecker_product(A, A) == sparse_matrix(FlintZZ, [4 0 0 0; 0 0 0 0; 0 0 0 0; 0 0 0 0]) +end From ea11246ddf6915cf9f3e882f3c883ae512273f51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20G=C3=B6ttgens?= Date: Wed, 27 Sep 2023 16:36:48 +0200 Subject: [PATCH 2/2] Fix `kronecker_product(::SMat, ::SMat)` --- src/Sparse/Matrix.jl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/Sparse/Matrix.jl b/src/Sparse/Matrix.jl index c510ebaa63..415a788510 100644 --- a/src/Sparse/Matrix.jl +++ b/src/Sparse/Matrix.jl @@ -1481,16 +1481,21 @@ function kronecker_product(A::SMat{T}, B::SMat{T}) where {T} for rB = B.rows p = Int[] v = T[] - o = (rA.pos[1]-1)*ncols(B) - for (pp, vv) = rA - for (qq, ww) = rB - push!(p, qq+o) - push!(v, vv*ww) + if !iszero(rA) + o = (rA.pos[1]-1)*ncols(B) + for (pp, vv) = rA + for (qq, ww) = rB + push!(p, qq+o) + push!(v, vv*ww) + end + o += ncols(B) end - o += ncols(B) end push!(C, sparse_row(base_ring(A), p, v)) end end + @assert nrows(C) == nrows(A)*nrows(B) + @assert ncols(C) <= ncols(A)*ncols(B) + C.c = ncols(A)*ncols(B) return C end