Skip to content

Commit

Permalink
Specialize adding/subtracting mixed Upper/LowerTriangular (#56149)
Browse files Browse the repository at this point in the history
Fixes JuliaLang/julia#56134
After this,
```julia
julia> using LinearAlgebra

julia> A = hermitianpart(rand(4, 4))
4×4 Hermitian{Float64, Matrix{Float64}}:
 0.387617  0.277226  0.67629   0.60678
 0.277226  0.894101  0.388416  0.489141
 0.67629   0.388416  0.100907  0.619955
 0.60678   0.489141  0.619955  0.452605

julia> B = UpperTriangular(A)
4×4 UpperTriangular{Float64, Hermitian{Float64, Matrix{Float64}}}:
 0.387617  0.277226  0.67629   0.60678
  ⋅        0.894101  0.388416  0.489141
  ⋅         ⋅        0.100907  0.619955
  ⋅         ⋅         ⋅        0.452605

julia> B - B'
4×4 Matrix{Float64}:
  0.0        0.277226   0.67629   0.60678
 -0.277226   0.0        0.388416  0.489141
 -0.67629   -0.388416   0.0       0.619955
 -0.60678   -0.489141  -0.619955  0.0
```
This preserves the band structure of the parent, if any:
```julia
julia> U = UpperTriangular(Diagonal(ones(4)))
4×4 UpperTriangular{Float64, Diagonal{Float64, Vector{Float64}}}:
 1.0  0.0  0.0  0.0
  ⋅   1.0  0.0  0.0
  ⋅    ⋅   1.0  0.0
  ⋅    ⋅    ⋅   1.0

julia> U - U'
4×4 Diagonal{Float64, Vector{Float64}}:
 0.0   ⋅    ⋅    ⋅ 
  ⋅   0.0   ⋅    ⋅ 
  ⋅    ⋅   0.0   ⋅ 
  ⋅    ⋅    ⋅   0.0
```
This doesn't fully work with partly initialized matrices, and would need
JuliaLang/julia#55312 for that.

The abstract triangular methods now construct matrices using
`similar(parent(U), size(U))` so that the destinations are fully
mutable.
```julia
julia> @invoke B::LinearAlgebra.AbstractTriangular - B'::LinearAlgebra.AbstractTriangular
4×4 Matrix{Float64}:
  0.0        0.277226   0.67629   0.60678
 -0.277226   0.0        0.388416  0.489141
 -0.67629   -0.388416   0.0       0.619955
 -0.60678   -0.489141  -0.619955  0.0
```

---------

Co-authored-by: Daniel Karrasch <[email protected]>
  • Loading branch information
jishnub and dkarrasch authored Oct 18, 2024
1 parent f4b76c0 commit fd8f17a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 2 deletions.
19 changes: 17 additions & 2 deletions src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ UnitUpperTriangular
const UpperOrUnitUpperTriangular{T,S} = Union{UpperTriangular{T,S}, UnitUpperTriangular{T,S}}
const LowerOrUnitLowerTriangular{T,S} = Union{LowerTriangular{T,S}, UnitLowerTriangular{T,S}}
const UpperOrLowerTriangular{T,S} = Union{UpperOrUnitUpperTriangular{T,S}, LowerOrUnitLowerTriangular{T,S}}
const UnitUpperOrUnitLowerTriangular{T,S} = Union{UnitUpperTriangular{T,S}, UnitLowerTriangular{T,S}}

uppertriangular(M) = UpperTriangular(M)
lowertriangular(M) = LowerTriangular(M)
Expand Down Expand Up @@ -181,6 +182,16 @@ copy(A::UpperOrLowerTriangular{<:Any, <:StridedMaybeAdjOrTransMat}) = copyto!(si

# then handle all methods that requires specific handling of upper/lower and unit diagonal

function full(A::Union{UpperTriangular,LowerTriangular})
return _triangularize(A)(parent(A))
end
function full(A::UnitUpperOrUnitLowerTriangular)
isupper = A isa UnitUpperTriangular
Ap = _triangularize(A)(parent(A), isupper ? 1 : -1)
Ap[diagind(Ap, IndexStyle(Ap))] = @view A[diagind(A, IndexStyle(A))]
return Ap
end

function full!(A::LowerTriangular)
B = A.data
tril!(B)
Expand Down Expand Up @@ -571,6 +582,8 @@ end
return A
end

_triangularize(::UpperOrUnitUpperTriangular) = triu
_triangularize(::LowerOrUnitLowerTriangular) = tril
_triangularize!(::UpperOrUnitUpperTriangular) = triu!
_triangularize!(::LowerOrUnitLowerTriangular) = tril!

Expand Down Expand Up @@ -880,7 +893,8 @@ function +(A::UnitLowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
LowerTriangular(tril(A.data, -1) + tril(B.data, -1) + 2I)
end
+(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) + copyto!(similar(parent(B)), B)
+(A::UpperOrLowerTriangular, B::UpperOrLowerTriangular) = full(A) + full(B)
+(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A), size(A)), A) + copyto!(similar(parent(B), size(B)), B)

function -(A::UpperTriangular, B::UpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
Expand Down Expand Up @@ -914,7 +928,8 @@ function -(A::UnitLowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
LowerTriangular(tril(A.data, -1) - tril(B.data, -1))
end
-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) - copyto!(similar(parent(B)), B)
-(A::UpperOrLowerTriangular, B::UpperOrLowerTriangular) = full(A) - full(B)
-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A), size(A)), A) - copyto!(similar(parent(B), size(B)), B)

function kron(A::UpperTriangular{T,<:StridedMaybeAdjOrTransMat}, B::UpperTriangular{S,<:StridedMaybeAdjOrTransMat}) where {T,S}
C = UpperTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)))
Expand Down
43 changes: 43 additions & 0 deletions test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1322,4 +1322,47 @@ end
end
end

@testset "addition/subtraction of mixed triangular" begin
for A in (Hermitian(rand(4, 4)), Diagonal(rand(5)))
for T in (UpperTriangular, LowerTriangular,
UnitUpperTriangular, UnitLowerTriangular)
B = T(A)
M = Matrix(B)
R = B - B'
if A isa Diagonal
@test R isa Diagonal
end
@test R == M - M'
R = B + B'
if A isa Diagonal
@test R isa Diagonal
end
@test R == M + M'
C = MyTriangular(B)
@test C - C' == M - M'
@test C + C' == M + M'
end
end
@testset "unfilled parent" begin
@testset for T in (UpperTriangular, LowerTriangular,
UnitUpperTriangular, UnitLowerTriangular)
F = Matrix{BigFloat}(undef, 2, 2)
B = T(F)
isupper = B isa Union{UpperTriangular, UnitUpperTriangular}
B[1+!isupper, 1+isupper] = 2
if !(B isa Union{UnitUpperTriangular, UnitLowerTriangular})
B[1,1] = B[2,2] = 3
end
M = Matrix(B)
@test B - B' == M - M'
@test B + B' == M + M'
@test B - copy(B') == M - M'
@test B + copy(B') == M + M'
C = MyTriangular(B)
@test C - C' == M - M'
@test C + C' == M + M'
end
end
end

end # module TestTriangular

0 comments on commit fd8f17a

Please sign in to comment.