Skip to content

Commit

Permalink
specialize copyto! and multiplication by numbers for Q from qr (#39533)
Browse files Browse the repository at this point in the history
* specialize copyto! and multiplication by numbers for Q from qr

This fixes two performance bugs reported in https://github.com/JuliaLang/julia/issues/38972
and https://github.com/JuliaLang/julia/issues/38972 (multiplication of `Q` from `qr` by a
`Diagonal` or `UniformScaling`). In particular, it improves the performance of generating
random orthogonal matrices as described in https://discourse.julialang.org/t/random-orthogonal-matrices/9779/7.

* fix typo in new qr tests

* resolve method ambiguity of copyto!
  • Loading branch information
ranocha authored Feb 23, 2021
1 parent 3230aef commit 8e949d6
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
39 changes: 39 additions & 0 deletions stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,31 @@ function getindex(Q::AbstractQ, i::Integer, j::Integer)
return dot(x, lmul!(Q, y))
end

# Specialization for copying a `Q` factor into a matrix. It avoids the generic
# element-by-element fallback, which would go through the slow `getindex` of `Q`.
# `dest` is first overwritten with the identity and `src` is then applied to it
# in place, so that `dest` ends up holding `src * I`.
function copyto!(dest::AbstractMatrix, src::AbstractQ)
    copyto!(dest, I)
    return lmul!(src, dest)
end
# This method is needed to resolve method ambiguities between the generic
# `copyto!` for `PermutedDimsArray` destinations and the `AbstractQ` specialization.
# The work is done on the parent array of `dest`, taking the permutation into
# account, and `dest` itself is returned.
function copyto!(dest::PermutedDimsArray{T,2,perm}, src::AbstractQ) where {T,perm}
    target = parent(dest)
    if perm == (1, 2)
        # identity permutation: the parent has the same layout as `dest`
        copyto!(target, src)
        return dest
    end
    @assert perm == (2, 1) # there are no other permutations of two indices
    if T <: Real
        # fill the parent with the adjoint of `src`, applied in place to the identity
        copyto!(target, I)
        lmul!(src', target)
    else
        # LAPACK does not offer inplace lmul!(transpose(Q), B) for complex Q
        scratch = similar(target)
        copyto!(scratch, I)
        rmul!(scratch, src)
        permutedims!(target, scratch, (2, 1))
    end
    return dest
end

## Multiplication by Q
### QB
lmul!(A::QRCompactWYQ{T,S}, B::StridedVecOrMat{T}) where {T<:BlasFloat, S<:StridedMatrix} =
Expand Down Expand Up @@ -590,6 +615,13 @@ function (*)(A::AbstractQ, B::StridedMatrix)
lmul!(Anew, Bnew)
end

# Multiply a `Q` factor by a scalar from the right, returning a dense result.
# The result is built by filling a fresh matrix with `b*I` and applying `A` to it
# in place, which avoids the slow element-wise `getindex` of `A`.
# The element type of the result is the promotion of `eltype(A)` and `typeof(b)`.
function (*)(A::AbstractQ, b::Number)
    TAb = promote_type(eltype(A), typeof(b))
    dest = similar(A, TAb)
    copyto!(dest, b*I)
    # Convert `A` to the promoted element type before applying it: the
    # LAPACK-backed `lmul!` methods dispatch on `A` and `dest` sharing one element
    # type, so e.g. a real `Q` times a complex number needs this conversion. This
    # matches what the `AbstractQ * StridedMatrix` method does.
    return lmul!(convert(AbstractMatrix{TAb}, A), dest)
end

### QcB
# In-place left multiplication of `B` by the adjoint of a compact-WY `Q` with real
# BLAS element type. Delegates to `LAPACK.gemqrt!` with side 'L' and trans 'T',
# passing the stored factors and the block-reflector matrix `T` of the parent `Q`.
function lmul!(
    adjA::Adjoint{<:Any,<:QRCompactWYQ{T,S}}, B::StridedVecOrMat{T}
) where {T<:BlasReal,S<:StridedMatrix}
    Q = adjA.parent
    return LAPACK.gemqrt!('L', 'T', Q.factors, Q.T, B)
end
Expand Down Expand Up @@ -683,6 +715,13 @@ function (*)(A::StridedMatrix, Q::AbstractQ)
return rmul!(copy_oftype(A, TAQ), convert(AbstractMatrix{TAQ}, Q))
end

# Multiply a `Q` factor by a scalar from the left, returning a dense result.
# The result is built by filling a fresh matrix with `a*I` and applying `B` to it
# in place from the right, which avoids the slow element-wise `getindex` of `B`.
# The element type of the result is the promotion of `typeof(a)` and `eltype(B)`.
function (*)(a::Number, B::AbstractQ)
    TaB = promote_type(typeof(a), eltype(B))
    dest = similar(B, TaB)
    copyto!(dest, a*I)
    # Convert `B` to the promoted element type before applying it, exactly as the
    # `StridedMatrix * AbstractQ` method does; otherwise mixed element types
    # (e.g. a complex number times a real `Q`) have no matching in-place `rmul!`.
    return rmul!(dest, convert(AbstractMatrix{TaB}, B))
end

### AQc
# In-place right multiplication of `A` by the adjoint of a compact-WY `Q` with real
# BLAS element type. Delegates to `LAPACK.gemqrt!` with side 'R' and trans 'T',
# passing the stored factors and the block-reflector matrix `T` of the parent `Q`.
function rmul!(
    A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:QRCompactWYQ{T}}
) where {T<:BlasReal}
    Q = adjB.parent
    return LAPACK.gemqrt!('R', 'T', Q.factors, Q.T, A)
end
Expand Down
49 changes: 49 additions & 0 deletions stdlib/LinearAlgebra/test/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,4 +322,53 @@ end
end
end

@testset "QR factorization of Q" begin
    # Factorizing the Q factor of a QR factorization again is expected to
    # reproduce Q itself together with an (approximately) identity R.
    for T in (Float32, Float64, ComplexF32, ComplexF64)
        Q1, R1 = qr(randn(T,5,5))
        Q2, R2 = qr(Q1)
        @test Q1 ≈ Q2
        @test R2 ≈ I
    end
end

@testset "Generation of orthogonal matrices" begin
    # Scaling the columns of Q by the signs of diag(R) must keep the result
    # orthogonal; this exercises Q * Diagonal.
    for T in (Float32, Float64)
        n = 5
        Q, R = qr(randn(T,n,n))
        O = Q * Diagonal(sign.(diag(R)))
        @test O' * O ≈ I
    end
end

@testset "Multiplication of Q by special matrices" begin
    # Products of Q with `Diagonal` and `UniformScaling` must agree with the
    # same products computed from the dense representation of Q.
    for T in (Float32, Float64, ComplexF32, ComplexF64)
        n = 5
        Q, R = qr(randn(T,n,n))
        Qmat = Matrix(Q)
        D = Diagonal(randn(T,n))
        @test Q * D ≈ Qmat * D
        @test D * Q ≈ D * Qmat
        J = 2*I
        @test Q * J ≈ Qmat * J
        @test J * Q ≈ J * Qmat
    end
end

@testset "copyto! for Q" begin
    # `copyto!` from Q must reproduce the dense Q for a plain matrix destination
    # and for `PermutedDimsArray` destinations with both possible permutations.
    for T in (Float32, Float64, ComplexF32, ComplexF64)
        n = 5
        Q, R = qr(randn(T,n,n))
        Qmat = Matrix(Q)
        dest1 = similar(Q)
        copyto!(dest1, Q)
        @test dest1 ≈ Qmat
        dest2 = PermutedDimsArray(similar(Q), (1, 2))
        copyto!(dest2, Q)
        @test dest2 ≈ Qmat
        dest3 = PermutedDimsArray(similar(Q), (2, 1))
        copyto!(dest3, Q)
        @test dest3 ≈ Qmat
    end
end

end # module TestQR

0 comments on commit 8e949d6

Please sign in to comment.