Skip to content

Commit

Permalink
Matmul: matprod_dest for Diagonal * SymTridiagonal (#55039)
Browse files Browse the repository at this point in the history
We specialize `matprod_dest` for the combination of a `Diagonal` and a
`SymTridiaognal`, in which case the destination is a `Tridiagonal`. With
this, the specialized methods `*(::Diagonal, ::SymTridiagonal)` and
`*(::SymTridiagonal, ::Diagonal)` don't need to be defined anymore,
which reduces potential method ambiguities.
  • Loading branch information
jishnub authored Jul 6, 2024
1 parent e318166 commit 1837202
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 15 deletions.
13 changes: 10 additions & 3 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -669,9 +669,16 @@ matprod_dest(A, B::StructuredMatrix, TS) = similar(A, TS, size(A))
matprod_dest(A::StructuredMatrix, B, TS) = similar(B, TS, size(B))
# diagonal is special, as it does not change the structure of the other matrix
# we call similar without a size to preserve the type of the matrix wherever possible
matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = similar(A, TS)
matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = similar(B, TS)
matprod_dest(A::Diagonal, B::Diagonal, TS) = similar(B, TS)
# reroute through _matprod_dest_diag to allow speicalizing on the type of the StructuredMatrix
# without defining methods for both the orderings
matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = _matprod_dest_diag(A, TS)
matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = _matprod_dest_diag(B, TS)
matprod_dest(A::Diagonal, B::Diagonal, TS) = _matprod_dest_diag(B, TS)
_matprod_dest_diag(A, TS) = similar(A, TS)
function _matprod_dest_diag(A::SymTridiagonal, TS)
n = size(A, 1)
Tridiagonal(similar(A, TS, n-1), similar(A, TS, n), similar(A, TS, n-1))
end

# Special handling for adj/trans vec
matprod_dest(A::Diagonal, B::AdjOrTransAbsVec, TS) = similar(B, TS)
Expand Down
12 changes: 0 additions & 12 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -751,18 +751,6 @@ function *(A::Bidiagonal, B::LowerOrUnitLowerTriangular)
return A.uplo == 'L' ? LowerTriangular(C) : C
end

function *(A::Diagonal, B::SymTridiagonal)
TS = promote_op(*, eltype(A), eltype(B))
out = Tridiagonal(similar(A, TS, size(A, 1)-1), similar(A, TS, size(A, 1)), similar(A, TS, size(A, 1)-1))
mul!(out, A, B)
end

function *(A::SymTridiagonal, B::Diagonal)
TS = promote_op(*, eltype(A), eltype(B))
out = Tridiagonal(similar(A, TS, size(A, 1)-1), similar(A, TS, size(A, 1)), similar(A, TS, size(A, 1)-1))
mul!(out, A, B)
end

function dot(x::AbstractVector, B::Bidiagonal, y::AbstractVector)
require_one_based_indexing(x, y)
nx, ny = length(x), length(y)
Expand Down

0 comments on commit 1837202

Please sign in to comment.