Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stdlib: faster kronecker product between hermitian and symmetric matrices #53186

Merged
merged 10 commits into from
Apr 18, 2024
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,8 @@ julia> reshape(kron(v,w), (length(w), length(v)))
```
"""
function kron(A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}) where {T,S}
R = Matrix{promote_op(*,T,S)}(undef, _kronsize(A, B))
return kron!(R, A, B)
C = Matrix{promote_op(*,T,S)}(undef, _kronsize(A, B))
return kron!(C, A, B)
end
function kron(a::AbstractVector{T}, b::AbstractVector{S}) where {T,S}
c = Vector{promote_op(*,T,S)}(undef, length(a)*length(b))
Expand Down
124 changes: 124 additions & 0 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,130 @@ for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Uni
end
end

function kron(A::Hermitian{T}, B::Hermitian{S}) where {T<:Union{Real,Complex},S<:Union{Real,Complex}}
resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L
C = Hermitian(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo)
return kron!(C, A, B)
end

function kron(A::Symmetric{T}, B::Symmetric{S}) where {T<:Number,S<:Number}
resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L
C = Symmetric(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo)
return kron!(C, A, B)
end

function kron!(C::Hermitian{<:Union{Real,Complex}}, A::Hermitian{<:Union{Real,Complex}}, B::Hermitian{<:Union{Real,Complex}})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L')
throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)"))
end
_hermkron!(C.data, A.data, B.data, conj, real, A.uplo, B.uplo)
return C
end

function kron!(C::Symmetric{<:Number}, A::Symmetric{<:Number}, B::Symmetric{<:Number})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L')
throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)"))
end
_hermkron!(C.data, A.data, B.data, identity, identity, A.uplo, B.uplo)
return C
end

function _hermkron!(C, A, B, conj::TC, real::TR, Auplo, Buplo) where {TC,TR}
n_A = size(A, 1)
n_B = size(B, 1)
@inbounds if Auplo == 'U' && Buplo == 'U'
for j = 1:n_A
jnB = (j - 1) * n_B
for i = 1:(j-1)
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
for k = 1:(l-1)
C[inB+k, jnB+l] = Aij * B[k, l]
C[inB+l, jnB+k] = Aij * conj(B[k, l])
end
C[inB+l, jnB+l] = Aij * real(B[l, l])
end
end
Ajj = real(A[j, j])
for l = 1:n_B
for k = 1:(l-1)
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
end
end
elseif Auplo == 'U' && Buplo == 'L'
for j = 1:n_A
jnB = (j - 1) * n_B
for i = 1:(j-1)
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
C[inB+l, jnB+l] = Aij * real(B[l, l])
for k = (l+1):n_B
C[inB+l, jnB+k] = Aij * conj(B[k, l])
C[inB+k, jnB+l] = Aij * B[k, l]
end
end
end
Ajj = real(A[j, j])
for l = 1:n_B
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
for k = (l+1):n_B
C[jnB+l, jnB+k] = Ajj * conj(B[k, l])
end
end
end
elseif Auplo == 'L' && Buplo == 'U'
for j = 1:n_A
jnB = (j - 1) * n_B
Ajj = real(A[j, j])
for l = 1:n_B
for k = 1:(l-1)
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
end
for i = (j+1):n_A
conjAij = conj(A[i, j])
inB = (i - 1) * n_B
for l = 1:n_B
for k = 1:(l-1)
C[jnB+k, inB+l] = conjAij * B[k, l]
C[jnB+l, inB+k] = conjAij * conj(B[k, l])
end
C[jnB+l, inB+l] = conjAij * real(B[l, l])
end
end
end
else #if Auplo == 'L' && Buplo == 'L'
for j = 1:n_A
jnB = (j - 1) * n_B
Ajj = real(A[j, j])
for l = 1:n_B
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
for k = (l+1):n_B
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
end
for i = (j+1):n_A
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
C[inB+l, jnB+l] = Aij * real(B[l, l])
for k = (l+1):n_B
C[inB+k, jnB+l] = Aij * B[k, l]
C[inB+l, jnB+k] = Aij * conj(B[k, l])
end
end
end
end
end
end

(-)(A::Symmetric) = Symmetric(parentof_applytri(-, A), sym_uplo(A.uplo))
(-)(A::Hermitian) = Hermitian(parentof_applytri(-, A), sym_uplo(A.uplo))

Expand Down
74 changes: 74 additions & 0 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,80 @@ for op in (:+, :-)
end
end

function kron(A::UpperTriangular{T}, B::UpperTriangular{S}) where {T<:Number,S<:Number}
C = UpperTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)))
return kron!(C, A, B)
end

function kron(A::LowerTriangular{T}, B::LowerTriangular{S}) where {T<:Number,S<:Number}
C = LowerTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)))
return kron!(C, A, B)
end

function kron!(C::UpperTriangular{<:Number}, A::UpperTriangular{<:Number}, B::UpperTriangular{<:Number})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
_triukron!(C.data, A.data, B.data)
return C
end

function kron!(C::LowerTriangular{<:Number}, A::LowerTriangular{<:Number}, B::LowerTriangular{<:Number})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
_trilkron!(C.data, A.data, B.data)
return C
end

function _triukron!(C, A, B)
n_A = size(A, 1)
n_B = size(B, 1)
@inbounds for j = 1:n_A
jnB = (j - 1) * n_B
for i = 1:(j-1)
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
for k = 1:l
C[inB+k, jnB+l] = Aij * B[k, l]
end
for k = 1:(l-1)
C[inB+l, jnB+k] = zero(eltype(C))
end
end
end
Ajj = A[j, j]
for l = 1:n_B
for k = 1:l
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
end
end
end

function _trilkron!(C, A, B)
n_A = size(A, 1)
n_B = size(B, 1)
@inbounds for j = 1:n_A
jnB = (j - 1) * n_B
Ajj = A[j, j]
for l = 1:n_B
for k = l:n_B
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
end
for i = (j+1):n_A
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
for k = l:n_B
C[inB+k, jnB+l] = Aij * B[k, l]
end
for k = (l+1):n_B
C[inB+l, jnB+k] = zero(eltype(C))
jishnub marked this conversation as resolved.
Show resolved Hide resolved
end
end
end
end
end

######################
# BlasFloat routines #
######################
Expand Down
35 changes: 35 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,28 @@ end
@test dot(symblockml, symblockml) ≈ dot(msymblockml, msymblockml)
end
end

@testset "kronecker product of symmetric and Hermitian matrices" begin
for mtype in (Symmetric, Hermitian)
symau = mtype(a, :U)
symal = mtype(a, :L)
msymau = Matrix(symau)
msymal = Matrix(symal)
for eltyc in (Float32, Float64, ComplexF32, ComplexF64, BigFloat, Int)
creal = randn(n, n)/2
cimag = randn(n, n)/2
c = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(creal, cimag) : creal)
symcu = mtype(c, :U)
symcl = mtype(c, :L)
msymcu = Matrix(symcu)
msymcl = Matrix(symcl)
@test kron(symau, symcu) ≈ kron(msymau, msymcu)
@test kron(symau, symcl) ≈ kron(msymau, msymcl)
@test kron(symal, symcu) ≈ kron(msymal, msymcu)
@test kron(symal, symcl) ≈ kron(msymal, msymcl)
end
end
end
end
end

Expand All @@ -487,6 +509,7 @@ end
@test S - S == MS - MS
@test S*2 == 2*S == 2*MS
@test S/2 == MS/2
@test kron(S,S) == kron(MS,MS)
end
@testset "mixed uplo" begin
Mu = Matrix{Complex{BigFloat}}(undef,2,2)
Expand All @@ -502,6 +525,8 @@ end
MSl = Matrix(Sl)
@test Su + Sl == Sl + Su == MSu + MSl
@test Su - Sl == -(Sl - Su) == MSu - MSl
@test kron(Su,Sl) == kron(MSu,MSl)
@test kron(Sl,Su) == kron(MSl,MSu)
end
end
end
Expand All @@ -517,6 +542,16 @@ end
@test dot(A, B) ≈ dot(Symmetric(A), Symmetric(B))
end

# let's make sure the analogous bug will not show up with kronecker products
@testset "kron Hermitian quaternion #52318" begin
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + t' for i in 1:2]
@test A == Hermitian(A) && B == Hermitian(B)
@test kron(A, B) ≈ kron(Hermitian(A), Hermitian(B))
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + transpose(t) for i in 1:2]
@test A == Symmetric(A) && B == Symmetric(B)
@test kron(A, B) ≈ kron(Symmetric(A), Symmetric(B))
end

#Issue #7647: test xsyevr, xheevr, xstevr drivers.
@testset "Eigenvalues in interval for $(typeof(Mi7647))" for Mi7647 in
(Symmetric(diagm(0 => 1.0:3.0)),
Expand Down
3 changes: 3 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ debug && println("Test basic type functionality")
# Binary operations
@test A1 + A2 == M1 + M2
@test A1 - A2 == M1 - M2
@test kron(A1,A2) == kron(M1,M2)

# Triangular-Triangular multiplication and division
@test A1*A2 ≈ M1*M2
Expand Down Expand Up @@ -1014,6 +1015,7 @@ end
@test 2\L == 2\B
@test real(L) == real(B)
@test imag(L) == imag(B)
@test kron(L,L) == kron(B,B)
@test transpose!(MT(copy(A))) == transpose(L) broken=!(A isa Matrix)
@test adjoint!(MT(copy(A))) == adjoint(L) broken=!(A isa Matrix)
end
Expand All @@ -1035,6 +1037,7 @@ end
@test 2\U == 2\B
@test real(U) == real(B)
@test imag(U) == imag(B)
@test kron(U,U) == kron(B,B)
@test transpose!(MT(copy(A))) == transpose(U) broken=!(A isa Matrix)
@test adjoint!(MT(copy(A))) == adjoint(U) broken=!(A isa Matrix)
end
Expand Down