diff --git a/lib/GPUArraysCore/src/GPUArraysCore.jl b/lib/GPUArraysCore/src/GPUArraysCore.jl index 0d0b5182..d1a53c98 100644 --- a/lib/GPUArraysCore/src/GPUArraysCore.jl +++ b/lib/GPUArraysCore/src/GPUArraysCore.jl @@ -7,7 +7,7 @@ using Adapt export AbstractGPUArray, AbstractGPUVector, AbstractGPUMatrix, AbstractGPUVecOrMat, WrappedGPUArray, AnyGPUArray, AbstractGPUArrayStyle, - AnyGPUArray, AnyGPUVector, AnyGPUMatrix + AnyGPUArray, AnyGPUVector, AnyGPUMatrix, AnyGPUVecOrMat """ AbstractGPUArray{T, N} <: DenseArray{T, N} @@ -27,6 +27,7 @@ const WrappedGPUArray{T,N} = WrappedArray{T,N,AbstractGPUArray,AbstractGPUArray{ const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N}, WrappedGPUArray{T,N}} const AnyGPUVector{T} = AnyGPUArray{T, 1} const AnyGPUMatrix{T} = AnyGPUArray{T, 2} +const AnyGPUVecOrMat{T} = Union{AnyGPUArray{T, 1}, AnyGPUArray{T, 2}} ## broadcasting diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 5619dd83..07c82a15 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -170,7 +170,7 @@ for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriang @eval Base.copyto!(A::$T{T, <:AbstractGPUArray{T,N}}, B::$T{T, <:AbstractGPUArray{T,N}}) where {T,N} = $T(copyto!(parent(A), parent(B))) end -function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T +function LinearAlgebra.tril!(A::AnyGPUMatrix{T}, d::Integer = 0) where T gpu_call(A, d; name="tril!") do ctx, _A, _d I = @cartesianidx _A i, j = Tuple(I) @@ -182,7 +182,7 @@ function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T return A end -function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T +function LinearAlgebra.triu!(A::AnyGPUMatrix{T}, d::Integer = 0) where T gpu_call(A, d; name="triu!") do ctx, _A, _d I = @cartesianidx _A i, j = Tuple(I) @@ -795,3 +795,21 @@ function Base.isone(x::AbstractGPUMatrix{T}) where {T} Array(y)[] end + +## QR + +import LinearAlgebra: QRPackedQ + +function LinearAlgebra.getproperty(F::QR{T,<:AnyGPUMatrix{T}}, d::Symbol) where {T} + m, n = size(F) + if d === :R + return triu!(view(getfield(F, :factors), 1:min(m,n), 1:n)) + elseif d === :Q + return QRPackedQ(getfield(F, :factors), F.τ) + else + getfield(F, d) + end +end + +Base.print_array(io::IO, Q::QRPackedQ{T,<:AnyGPUMatrix{T},<:AnyGPUMatrix{T}}) where {T} = + Base.print_array(io, collect(adapt(ToArray(), Q))) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index d84bb5bd..e3e73f68 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -378,3 +378,10 @@ end @test isrealfloattype(typeof(opnorm(AT(mat), p))) end end + +@testsuite "QR" (AT, eltypes)->begin + @testset "get property" for dims in [(3,5), (3,3), (5,3)], + prop in [:Q, :R], T in eltypes + @test compare(x -> getproperty(qr(x), prop), AT, rand(T, dims)) + end +end