Skip to content

Commit

Permalink
Preserve structure in scaling triangular matrices by NaN (#55310)
Browse files Browse the repository at this point in the history
Addresses the `Matrix` cases from
#55296. This restores the
behavior to match that on v1.10, and preserves the structure of the
matrix on scaling by `NaN`. This behavior is consistent with the
strong-zero behavior for other structured matrix types, and the scaling
may be seen to be occurring in the vector space corresponding to the
filled elements.

After this,
```julia
julia> UpperTriangular(rand(2,2)) * NaN
2×2 UpperTriangular{Float64, Matrix{Float64}}:
 NaN    NaN
    ⋅   NaN
```
cc. @mikmoore
  • Loading branch information
jishnub authored Aug 1, 2024
1 parent 7c6a1a2 commit 0ef8a91
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 4 deletions.
61 changes: 57 additions & 4 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,43 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
return A
end

function _trirdiv!(A::UpperTriangular, B::UpperOrUnitUpperTriangular, c::Number)
n = checksize1(A, B)
for j in 1:n
for i in 1:j
@inbounds A[i, j] = B[i, j] / c
end
end
return A
end
function _trirdiv!(A::LowerTriangular, B::LowerOrUnitLowerTriangular, c::Number)
n = checksize1(A, B)
for j in 1:n
for i in j:n
@inbounds A[i, j] = B[i, j] / c
end
end
return A
end
function _trildiv!(A::UpperTriangular, c::Number, B::UpperOrUnitUpperTriangular)
n = checksize1(A, B)
for j in 1:n
for i in 1:j
@inbounds A[i, j] = c \ B[i, j]
end
end
return A
end
function _trildiv!(A::LowerTriangular, c::Number, B::LowerOrUnitLowerTriangular)
n = checksize1(A, B)
for j in 1:n
for i in j:n
@inbounds A[i, j] = c \ B[i, j]
end
end
return A
end

rmul!(A::UpperOrLowerTriangular, c::Number) = @inline _triscale!(A, A, c, MulAddMul())
lmul!(c::Number, A::UpperOrLowerTriangular) = @inline _triscale!(A, c, A, MulAddMul())

Expand Down Expand Up @@ -1095,7 +1132,11 @@ for (t, unitt) in ((UpperTriangular, UnitUpperTriangular),
tstrided = t{<:Any, <:StridedMaybeAdjOrTransMat}
@eval begin
(*)(A::$t, x::Number) = $t(A.data*x)
(*)(A::$tstrided, x::Number) = A .* x
function (*)(A::$tstrided, x::Number)
eltype_dest = promote_op(*, eltype(A), typeof(x))
dest = $t(similar(parent(A), eltype_dest))
_triscale!(dest, x, A, MulAddMul())
end

function (*)(A::$unitt, x::Number)
B = $t(A.data)*x
Expand All @@ -1106,7 +1147,11 @@ for (t, unitt) in ((UpperTriangular, UnitUpperTriangular),
end

(*)(x::Number, A::$t) = $t(x*A.data)
(*)(x::Number, A::$tstrided) = x .* A
function (*)(x::Number, A::$tstrided)
eltype_dest = promote_op(*, typeof(x), eltype(A))
dest = $t(similar(parent(A), eltype_dest))
_triscale!(dest, x, A, MulAddMul())
end

function (*)(x::Number, A::$unitt)
B = x*$t(A.data)
Expand All @@ -1117,7 +1162,11 @@ for (t, unitt) in ((UpperTriangular, UnitUpperTriangular),
end

(/)(A::$t, x::Number) = $t(A.data/x)
(/)(A::$tstrided, x::Number) = A ./ x
function (/)(A::$tstrided, x::Number)
eltype_dest = promote_op(/, eltype(A), typeof(x))
dest = $t(similar(parent(A), eltype_dest))
_trirdiv!(dest, A, x)
end

function (/)(A::$unitt, x::Number)
B = $t(A.data)/x
Expand All @@ -1129,7 +1178,11 @@ for (t, unitt) in ((UpperTriangular, UnitUpperTriangular),
end

(\)(x::Number, A::$t) = $t(x\A.data)
(\)(x::Number, A::$tstrided) = x .\ A
function (\)(x::Number, A::$tstrided)
eltype_dest = promote_op(\, typeof(x), eltype(A))
dest = $t(similar(parent(A), eltype_dest))
_trildiv!(dest, x, A)
end

function (\)(x::Number, A::$unitt)
B = x\$t(A.data)
Expand Down
14 changes: 14 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1180,4 +1180,18 @@ end
@test V == Diagonal([1, 1])
end

@testset "preserve structure in scaling by NaN" begin
M = rand(Int8,2,2)
for (Ts, TD) in (((UpperTriangular, UnitUpperTriangular), UpperTriangular),
((LowerTriangular, UnitLowerTriangular), LowerTriangular))
for T in Ts
U = T(M)
for V in (U * NaN, NaN * U, U / NaN, NaN \ U)
@test V isa TD{Float64, Matrix{Float64}}
@test all(isnan, diag(V))
end
end
end
end

end # module TestTriangular

0 comments on commit 0ef8a91

Please sign in to comment.