Skip to content

Commit

Permalink
specialize copyto! and multiplication by numbers for Q from qr
Browse files Browse the repository at this point in the history
This fixes some performance problems reported in #38972 and #37102
(multiplying `Q` by a `Diagonal` or a `UniformScaling`). This improves
the performance of generating random orthogonal matrices as described
in https://discourse.julialang.org/t/random-orthogonal-matrices/9779/7
significantly.
  • Loading branch information
ranocha committed Feb 5, 2021
1 parent 0ea19e4 commit ecd925c
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
20 changes: 20 additions & 0 deletions stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,12 @@ function getindex(Q::AbstractQ, i::Integer, j::Integer)
return dot(x, lmul!(Q, y))
end

# specialization avoiding the fallback using slow `getindex`
function copyto!(dest::AbstractMatrix, src::LinearAlgebra.AbstractQ)
    # Overwrite `dest` with the identity pattern first ...
    copyto!(dest, I)
    # ... then apply `src` to it from the left in place, so that `dest` ends up
    # holding the entries of `src` without any element-wise `getindex` calls.
    # The result of `lmul!` is what this method returns.
    return lmul!(src, dest)
end

## Multiplication by Q
### QB
lmul!(A::QRCompactWYQ{T,S}, B::StridedVecOrMat{T}) where {T<:BlasFloat, S<:StridedMatrix} =
Expand Down Expand Up @@ -590,6 +596,13 @@ function (*)(A::AbstractQ, B::StridedMatrix)
lmul!(Anew, Bnew)
end

# Multiply `Q` by a number from the right.
#
# Instead of scaling `A` entry by entry (which would go through `getindex`),
# a buffer holding `b*I` is built and `A` is applied to it from the left in
# place. The result has element type `promote_type(eltype(A), typeof(b))`.
function (*)(A::AbstractQ, b::Number)
    TAb = promote_type(eltype(A), typeof(b))
    dest = similar(A, TAb)
    copyto!(dest, b*I)
    # Bring `A` to the promoted element type before the in-place product, as
    # the other `*` methods for `Q` do: the BLAS-backed `lmul!` methods need
    # `Q` and the target array to share one element type, so e.g. a real `Q`
    # times a complex number would otherwise find no applicable method.
    return lmul!(convert(AbstractMatrix{TAb}, A), dest)
end

### QcB
# In-place left multiplication of `B` by the adjoint of a compact-WY `Q` with
# real BLAS element type. The wrapped `Q` is unwrapped and its `factors` and
# `T` fields are handed to `LAPACK.gemqrt!` with side 'L' and trans 'T'.
function lmul!(adjA::Adjoint{<:Any,<:QRCompactWYQ{T,S}}, B::StridedVecOrMat{T}) where {T<:BlasReal,S<:StridedMatrix}
    Q = adjA.parent
    return LAPACK.gemqrt!('L', 'T', Q.factors, Q.T, B)
end
Expand Down Expand Up @@ -683,6 +696,13 @@ function (*)(A::StridedMatrix, Q::AbstractQ)
return rmul!(copy_oftype(A, TAQ), convert(AbstractMatrix{TAQ}, Q))
end

# Multiply `Q` by a number from the left.
#
# A buffer holding `a*I` is built and `B` is applied to it from the right in
# place, avoiding element-wise access to `B`. The result has element type
# `promote_type(typeof(a), eltype(B))`.
function (*)(a::Number, B::AbstractQ)
    TaB = promote_type(typeof(a), eltype(B))
    dest = similar(B, TaB)
    copyto!(dest, a*I)
    # Convert `B` to the promoted element type first, mirroring the
    # `StridedMatrix * Q` method: the in-place `rmul!` kernels expect the
    # array and `Q` to share one element type (e.g. complex `a`, real `B`).
    return rmul!(dest, convert(AbstractMatrix{TaB}, B))
end

### AQc
# In-place right multiplication of `A` by the adjoint of a compact-WY `Q` with
# real BLAS element type. The wrapped `Q` is unwrapped and its `factors` and
# `T` fields are handed to `LAPACK.gemqrt!` with side 'R' and trans 'T'.
function rmul!(A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:QRCompactWYQ{T}}) where {T<:BlasReal}
    Q = adjB.parent
    return LAPACK.gemqrt!('R', 'T', Q.factors, Q.T, A)
end
Expand Down
32 changes: 32 additions & 0 deletions stdlib/LinearAlgebra/test/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,4 +322,36 @@ end
end
end

@testset "QR factorization of Q" begin
    # Factorizing the `Q` factor of a QR decomposition again should give back
    # the same `Q` together with an `R` that is (approximately) the identity.
    for T in (Float32, Float64, ComplexF32, ComplexF64)
        Q1, R1 = qr(randn(T,5,5))
        Q2, R2 = qr(Q1)
        @test Q1 ≈ Q2
        @test R2 ≈ I
    end
end

@testset "Generation of orthogonal matrices" begin
    # `Q` scaled column-wise by the signs of the diagonal of `R` must still be
    # orthogonal; this exercises the product of `Q` with a `Diagonal`.
    for T in (Float32, Float64)
        n = 5
        Q, R = qr(randn(T,n,n))
        O = Q * Diagonal(sign.(diag(R)))
        @test O' * O ≈ I
    end
end

@testset "Multiplication of Q by special matrices" begin
    # Products of `Q` with a `Diagonal` and with a `UniformScaling`, from
    # either side, are compared against the same products computed with the
    # materialized matrix `Qmat`.
    for T in (Float32, Float64, ComplexF32, ComplexF64)
        n = 5
        Q, R = qr(randn(T,n,n))
        Qmat = Matrix(Q)
        D = Diagonal(randn(n))
        @test Q * D ≈ Qmat * D
        @test D * Q ≈ D * Qmat
        J = 2*I
        @test Q * J ≈ Qmat * J
        @test J * Q ≈ J * Qmat
    end
end

end # module TestQR

0 comments on commit ecd925c

Please sign in to comment.