Skip to content

Commit

Permalink
Merge pull request #24137 from Sacha0/repfullopts
Browse files Browse the repository at this point in the history
reduce allocation in a few linalg functions while removing full calls
  • Loading branch information
Sacha0 authored Oct 19, 2017
2 parents 9d3b6c7 + 3dbbf5b commit 4ae0155
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 33 deletions.
73 changes: 42 additions & 31 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ exp(A::StridedMatrix{<:Union{Integer,Complex{<:Integer}}}) = exp!(float.(A))
function exp!(A::StridedMatrix{T}) where T<:BlasFloat
n = checksquare(A)
if ishermitian(A)
return full(exp(Hermitian(A)))
return copytri!(parent(exp(Hermitian(A))), 'U', true)
end
ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A
nA = norm(A, 1)
Expand Down Expand Up @@ -601,13 +601,14 @@ julia> log(A)
function log(A::StridedMatrix)
# If possible, use diagonalization
if ishermitian(A)
return full(log(Hermitian(A)))
logHermA = log(Hermitian(A))
return isa(logHermA, Hermitian) ? copytri!(parent(logHermA), 'U', true) : parent(logHermA)
end

# Use Schur decomposition
n = checksquare(A)
if istriu(A)
return full(log(UpperTriangular(complex(A))))
return triu!(parent(log(UpperTriangular(complex(A)))))
else
if isreal(A)
SchurF = schurfact(real(A))
Expand Down Expand Up @@ -658,27 +659,28 @@ julia> sqrt(A)
"""
function sqrt(A::StridedMatrix{<:Real})
if issymmetric(A)
return full(sqrt(Symmetric(A)))
return copytri!(parent(sqrt(Symmetric(A))), 'U')
end
n = checksquare(A)
if istriu(A)
return full(sqrt(UpperTriangular(A)))
return triu!(parent(sqrt(UpperTriangular(A))))
else
SchurF = schurfact(complex(A))
R = full(sqrt(UpperTriangular(SchurF[:T])))
R = triu!(parent(sqrt(UpperTriangular(SchurF[:T])))) # unwrapping unnecessary?
return SchurF[:vectors] * R * SchurF[:vectors]'
end
end
function sqrt(A::StridedMatrix{<:Complex})
if ishermitian(A)
return full(sqrt(Hermitian(A)))
sqrtHermA = sqrt(Hermitian(A))
return isa(sqrtHermA, Hermitian) ? copytri!(parent(sqrtHermA), 'U', true) : parent(sqrtHermA)
end
n = checksquare(A)
if istriu(A)
return full(sqrt(UpperTriangular(A)))
return triu!(parent(sqrt(UpperTriangular(A))))
else
SchurF = schurfact(A)
R = full(sqrt(UpperTriangular(SchurF[:T])))
R = triu!(parent(sqrt(UpperTriangular(SchurF[:T])))) # unwrapping unnecessary?
return SchurF[:vectors] * R * SchurF[:vectors]'
end
end
Expand Down Expand Up @@ -716,13 +718,13 @@ julia> cos(ones(2, 2))
"""
function cos(A::AbstractMatrix{<:Real})
if issymmetric(A)
return full(cos(Symmetric(A)))
return copytri!(parent(cos(Symmetric(A))), 'U')
end
return real(exp!(im*A))
end
function cos(A::AbstractMatrix{<:Complex})
if ishermitian(A)
return full(cos(Hermitian(A)))
return copytri!(parent(cos(Hermitian(A))), 'U', true)
end
X = exp!(im*A)
X .= (X .+ exp!(-im*A)) ./ 2
Expand All @@ -747,13 +749,13 @@ julia> sin(ones(2, 2))
"""
function sin(A::AbstractMatrix{<:Real})
if issymmetric(A)
return full(sin(Symmetric(A)))
return copytri!(parent(sin(Symmetric(A))), 'U')
end
return imag(exp!(im*A))
end
function sin(A::AbstractMatrix{<:Complex})
if ishermitian(A)
return full(sin(Hermitian(A)))
return copytri!(parent(sin(Hermitian(A))), 'U', true)
end
X = exp!(im*A)
Y = exp!(-im*A)
Expand Down Expand Up @@ -786,14 +788,20 @@ julia> C
"""
function sincos(A::AbstractMatrix{<:Real})
if issymmetric(A)
return full.(sincos(Symmetric(A)))
symsinA, symcosA = sincos(Symmetric(A))
sinA = copytri!(parent(symsinA), 'U')
cosA = copytri!(parent(symcosA), 'U')
return sinA, cosA
end
c, s = reim(exp!(im*A))
return s, c
end
function sincos(A::AbstractMatrix{<:Complex})
if ishermitian(A)
return full.(sincos(Hermitian(A)))
hermsinA, hermcosA = sincos(Hermitian(A))
sinA = copytri!(parent(hermsinA), 'U', true)
cosA = copytri!(parent(hermcosA), 'U', true)
return sinA, cosA
end
X = exp!(im*A)
Y = exp!(-im*A)
Expand Down Expand Up @@ -823,7 +831,7 @@ julia> tan(ones(2, 2))
"""
function tan(A::AbstractMatrix)
if ishermitian(A)
return full(tan(Hermitian(A)))
return copytri!(parent(tan(Hermitian(A))), 'U', true)
end
S, C = sincos(A)
S /= C
Expand All @@ -837,7 +845,7 @@ Compute the matrix hyperbolic cosine of a square matrix `A`.
"""
function cosh(A::AbstractMatrix)
if ishermitian(A)
return full(cosh(Hermitian(A)))
return copytri!(parent(cosh(Hermitian(A))), 'U', true)
end
X = exp(A)
X .= (X .+ exp!(-A)) ./ 2
Expand All @@ -851,7 +859,7 @@ Compute the matrix hyperbolic sine of a square matrix `A`.
"""
function sinh(A::AbstractMatrix)
if ishermitian(A)
return full(sinh(Hermitian(A)))
return copytri!(parent(sinh(Hermitian(A))), 'U', true)
end
X = exp(A)
X .= (X .- exp!(-A)) ./ 2
Expand All @@ -865,7 +873,7 @@ Compute the matrix hyperbolic tangent of a square matrix `A`.
"""
function tanh(A::AbstractMatrix)
if ishermitian(A)
return full(tanh(Hermitian(A)))
return copytri!(parent(tanh(Hermitian(A))), 'U', true)
end
X = exp(A)
Y = exp!(-A)
Expand Down Expand Up @@ -900,11 +908,12 @@ julia> acos(cos([0.5 0.1; -0.2 0.3]))
"""
function acos(A::AbstractMatrix)
if ishermitian(A)
return full(acos(Hermitian(A)))
acosHermA = acos(Hermitian(A))
return isa(acosHermA, Hermitian) ? copytri!(parent(acosHermA), 'U', true) : parent(acosHermA)
end
SchurF = schurfact(complex(A))
U = UpperTriangular(SchurF.T)
R = full(-im * log(U + im * sqrt(I - U^2)))
R = triu!(parent(-im * log(U + im * sqrt(I - U^2))))
return SchurF.Z * R * SchurF.Z'
end

Expand All @@ -930,11 +939,12 @@ julia> asin(sin([0.5 0.1; -0.2 0.3]))
"""
function asin(A::AbstractMatrix)
if ishermitian(A)
return full(asin(Hermitian(A)))
asinHermA = asin(Hermitian(A))
return isa(asinHermA, Hermitian) ? copytri!(parent(asinHermA), 'U', true) : parent(asinHermA)
end
SchurF = schurfact(complex(A))
U = UpperTriangular(SchurF.T)
R = full(-im * log(im * U + sqrt(I - U^2)))
R = triu!(parent(-im * log(im * U + sqrt(I - U^2))))
return SchurF.Z * R * SchurF.Z'
end

Expand All @@ -960,11 +970,11 @@ julia> atan(tan([0.5 0.1; -0.2 0.3]))
"""
function atan(A::AbstractMatrix)
if ishermitian(A)
return full(atan(Hermitian(A)))
return copytri!(parent(atan(Hermitian(A))), 'U', true)
end
SchurF = schurfact(complex(A))
U = im * UpperTriangular(SchurF.T)
R = full(log((I + U) / (I - U)) / 2im)
R = triu!(parent(log((I + U) / (I - U)) / 2im))
return SchurF.Z * R * SchurF.Z'
end

Expand All @@ -978,11 +988,12 @@ logarithmic formulas used to compute this function, see [^AH16_4].
"""
function acosh(A::AbstractMatrix)
if ishermitian(A)
return full(acosh(Hermitian(A)))
acoshHermA = acosh(Hermitian(A))
return isa(acoshHermA, Hermitian) ? copytri!(parent(acoshHermA), 'U', true) : parent(acoshHermA)
end
SchurF = schurfact(complex(A))
U = UpperTriangular(SchurF.T)
R = full(log(U + sqrt(U - I) * sqrt(U + I)))
R = triu!(parent(log(U + sqrt(U - I) * sqrt(U + I))))
return SchurF.Z * R * SchurF.Z'
end

Expand All @@ -996,11 +1007,11 @@ logarithmic formulas used to compute this function, see [^AH16_5].
"""
function asinh(A::AbstractMatrix)
if ishermitian(A)
return full(asinh(Hermitian(A)))
return copytri!(parent(asinh(Hermitian(A))), 'U', true)
end
SchurF = schurfact(complex(A))
U = UpperTriangular(SchurF.T)
R = full(log(U + sqrt(I + U^2)))
R = triu!(parent(log(U + sqrt(I + U^2))))
return SchurF.Z * R * SchurF.Z'
end

Expand All @@ -1014,11 +1025,11 @@ logarithmic formulas used to compute this function, see [^AH16_6].
"""
function atanh(A::AbstractMatrix)
if ishermitian(A)
return full(atanh(Hermitian(A)))
return copytri!(parent(atanh(Hermitian(A))), 'U', true)
end
SchurF = schurfact(complex(A))
U = UpperTriangular(SchurF.T)
R = full(log((I + U) / (I - U)) / 2)
R = triu!(parent(log((I + U) / (I - U)) / 2))
return SchurF.Z * R * SchurF.Z'
end

Expand Down
4 changes: 2 additions & 2 deletions base/linalg/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ for (t1, t2) in ((:UnitUpperTriangular, :UpperTriangular),
end

function (-)(J::UniformScaling, UL::Union{UpperTriangular,UnitUpperTriangular})
ULnew = similar(full(UL), promote_type(eltype(J), eltype(UL)))
ULnew = similar(parent(UL), promote_type(eltype(J), eltype(UL)))
n = size(ULnew, 1)
ULold = UL.data
for j = 1:n
Expand All @@ -126,7 +126,7 @@ function (-)(J::UniformScaling, UL::Union{UpperTriangular,UnitUpperTriangular})
return UpperTriangular(ULnew)
end
function (-)(J::UniformScaling, UL::Union{LowerTriangular,UnitLowerTriangular})
ULnew = similar(full(UL), promote_type(eltype(J), eltype(UL)))
ULnew = similar(parent(UL), promote_type(eltype(J), eltype(UL)))
n = size(ULnew, 1)
ULold = UL.data
for j = 1:n
Expand Down

0 comments on commit 4ae0155

Please sign in to comment.