diff --git a/stdlib/LinearAlgebra/src/qr.jl b/stdlib/LinearAlgebra/src/qr.jl index a76577bb63a0d..c9ba49d2cd1ad 100644 --- a/stdlib/LinearAlgebra/src/qr.jl +++ b/stdlib/LinearAlgebra/src/qr.jl @@ -533,6 +533,31 @@ function getindex(Q::AbstractQ, i::Integer, j::Integer) return dot(x, lmul!(Q, y)) end +# specialization avoiding the fallback using slow `getindex` +function copyto!(dest::AbstractMatrix, src::AbstractQ) + copyto!(dest, I) + lmul!(src, dest) +end +# needed to resolve method ambiguities +function copyto!(dest::PermutedDimsArray{T,2,perm}, src::AbstractQ) where {T,perm} + if perm == (1, 2) + copyto!(parent(dest), src) + else + @assert perm == (2, 1) # there are no other permutations of two indices + if T <: Real + copyto!(parent(dest), I) + lmul!(src', parent(dest)) + else + # LAPACK does not offer inplace lmul!(transpose(Q), B) for complex Q + tmp = similar(parent(dest)) + copyto!(tmp, I) + rmul!(tmp, src) + permutedims!(parent(dest), tmp, (2, 1)) + end + end + return dest +end + ## Multiplication by Q ### QB lmul!(A::QRCompactWYQ{T,S}, B::StridedVecOrMat{T}) where {T<:BlasFloat, S<:StridedMatrix} = @@ -590,6 +615,13 @@ function (*)(A::AbstractQ, B::StridedMatrix) lmul!(Anew, Bnew) end +function (*)(A::AbstractQ, b::Number) + TAb = promote_type(eltype(A), typeof(b)) + dest = similar(A, TAb) + copyto!(dest, b*I) + lmul!(A, dest) +end + ### QcB lmul!(adjA::Adjoint{<:Any,<:QRCompactWYQ{T,S}}, B::StridedVecOrMat{T}) where {T<:BlasReal,S<:StridedMatrix} = (A = adjA.parent; LAPACK.gemqrt!('L','T',A.factors,A.T,B)) @@ -683,6 +715,13 @@ function (*)(A::StridedMatrix, Q::AbstractQ) return rmul!(copy_oftype(A, TAQ), convert(AbstractMatrix{TAQ}, Q)) end +function (*)(a::Number, B::AbstractQ) + TaB = promote_type(typeof(a), eltype(B)) + dest = similar(B, TaB) + copyto!(dest, a*I) + rmul!(dest, B) +end + ### AQc rmul!(A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:QRCompactWYQ{T}}) where {T<:BlasReal} = (B = adjB.parent; LAPACK.gemqrt!('R','T',B.factors,B.T,A)) diff --git a/stdlib/LinearAlgebra/test/qr.jl b/stdlib/LinearAlgebra/test/qr.jl index b4e6d383bb262..394b371e02eac 100644 --- a/stdlib/LinearAlgebra/test/qr.jl +++ b/stdlib/LinearAlgebra/test/qr.jl @@ -322,4 +322,53 @@ end end end +@testset "QR factorization of Q" begin + for T in (Float32, Float64, ComplexF32, ComplexF64) + Q1, R1 = qr(randn(T,5,5)) + Q2, R2 = qr(Q1) + @test Q1 ≈ Q2 + @test R2 ≈ I + end +end + +@testset "Generation of orthogonal matrices" begin + for T in (Float32, Float64) + n = 5 + Q, R = qr(randn(T,n,n)) + O = Q * Diagonal(sign.(diag(R))) + @test O' * O ≈ I + end +end + +@testset "Multiplication of Q by special matrices" begin + for T in (Float32, Float64, ComplexF32, ComplexF64) + n = 5 + Q, R = qr(randn(T,n,n)) + Qmat = Matrix(Q) + D = Diagonal(randn(T,n)) + @test Q * D ≈ Qmat * D + @test D * Q ≈ D * Qmat + J = 2*I + @test Q * J ≈ Qmat * J + @test J * Q ≈ J * Qmat + end +end + +@testset "copyto! for Q" begin + for T in (Float32, Float64, ComplexF32, ComplexF64) + n = 5 + Q, R = qr(randn(T,n,n)) + Qmat = Matrix(Q) + dest1 = similar(Q) + copyto!(dest1, Q) + @test dest1 ≈ Qmat + dest2 = PermutedDimsArray(similar(Q), (1, 2)) + copyto!(dest2, Q) + @test dest2 ≈ Qmat + dest3 = PermutedDimsArray(similar(Q), (2, 1)) + copyto!(dest3, Q) + @test dest3 ≈ Qmat + end +end + end # module TestQR