Skip to content

Commit

Permalink
support non-blas types (#51)
Browse files Browse the repository at this point in the history
* hessenberg merged

* hessenberg merged

* pfaffian.jl

* pfaffian.jl

* non blas type

* non blas type

* non blas type

* eigen complex friendly

* eigen complex friendly
  • Loading branch information
smataigne authored Aug 15, 2022
1 parent 54067de commit 2820ca4
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 118 deletions.
57 changes: 49 additions & 8 deletions src/eigen.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,47 @@
# Based on eigen.jl in Julia. License is MIT: https://julialang.org/license


@views function LA.eigvals!(A::SkewHermitian, sortby::Union{Function,Nothing}=nothing)
@views function LA.eigvals!(A::SkewHermitian{<:Real}, sortby::Union{Function,Nothing}=nothing)
vals = skeweigvals!(A)
!isnothing(sortby) && sort!(vals, by=sortby)
return complex.(0, vals)
end

@views function LA.eigvals!(A::SkewHermitian, irange::UnitRange)
@views function LA.eigvals!(A::SkewHermitian{<:Real}, irange::UnitRange)
vals = skeweigvals!(A,irange)
return complex.(0, vals)
end

@views function LA.eigvals!(A::SkewHermitian, vl::Real,vh::Real)
@views function LA.eigvals!(A::SkewHermitian{<:Real}, vl::Real,vh::Real)
vals = skeweigvals!(A,-vh,-vl)
return complex.(0, vals)
end

@views function LA.eigvals!(A::SkewHermitian{<:Complex}, sortby::Union{Function,Nothing}=nothing)
H=Hermitian(A.data.*1im)
if sortby===nothing
return complex.(0, - eigvals!(H))
end
vals=eigvals!(H,sortby)
reverse!(vals)
vals.= .-vals
return complex.(0, vals)
end

@views function LA.eigvals!(A::SkewHermitian{<:Complex}, irange::UnitRange)
H=Hermitian(A.data.*1im)
vals=eigvals!(H,-irange)
vals.= .-vals
return complex.(0, vals)
end

@views function LA.eigvals!(A::SkewHermitian{<:Complex}, vl::Real,vh::Real)
H=Hermitian(A.data.*1im)
vals=eigvals!(H,-vh,-vl)
vals.= .-vals
return complex.(0, vals)
end

LA.eigvals(A::SkewHermitian, irange::UnitRange) =
LA.eigvals!(copyeigtype(A), irange)
LA.eigvals(A::SkewHermitian, vl::Real,vh::Real) =
Expand Down Expand Up @@ -102,24 +127,34 @@ end
end


@views function LA.eigen!(A::SkewHermitian)
@views function LA.eigen!(A::SkewHermitian{<:Real})
vals,Qr,Qim = skeweigen!(A)
return Eigen(vals,complex.(Qr,Qim))
end

copyeigtype(A::SkewHermitian) = copyto!(similar(A, LA.eigtype(eltype(A))), A)

@views function LA.eigen!(A::SkewHermitian{T}) where {T<:Complex}
H=Hermitian(A.data.*1im)
Eig=eigen!(H)
skew_Eig=Eigen(complex.(0,-Eig.values), Eig.vectors)
return skew_Eig
end

LA.eigen(A::SkewHermitian) = LA.eigen!(copyeigtype(A))

@views function LA.svdvals!(A::SkewHermitian)
n=size(A,1)
@views function LA.svdvals!(A::SkewHermitian{<:Real})
vals = skeweigvals!(A)
vals .= abs.(vals)
return sort!(vals; rev=true)
end
@views function LA.svdvals!(A::SkewHermitian{<:Complex})
H=Hermitian(A.data.*1im)
return svdvals!(H)
end
LA.svdvals(A::SkewHermitian) = svdvals!(copyeigtype(A))

@views function LA.svd!(A::SkewHermitian)
@views function LA.svd!(A::SkewHermitian{<:Real})
n=size(A,1)
E=eigen!(A)
U=E.vectors
Expand All @@ -138,5 +173,11 @@ LA.svdvals(A::SkewHermitian) = svdvals!(copyeigtype(A))
end
return LA.SVD(U,vals,adjoint(V))
end
@views function LA.svd(A::SkewHermitian{T}) where {T<:Complex}
H=Hermitian(A.data.*1im)
Svd=svd(H)
skew_Svd=SVD(Svd.U,Svd.S,(Svd.Vt).*(-1im))
return skew_Svd
end

LA.svd(A::SkewHermitian) = svd!(copyeigtype(A))
LA.svd(A::SkewHermitian{<:Real}) = svd!(copyeigtype(A))
1 change: 0 additions & 1 deletion src/pfaffian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,3 @@ function pfaffian!(A::AbstractMatrix{<:Real})
isskewhermitian(A) || throw(ArgumentError("Pfaffian requires a skew-Hermitian matrix"))
return _pfaffian!(SkewHermitian(A))
end

11 changes: 7 additions & 4 deletions src/skewhermitian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,11 @@ function LA.dot(A::SkewHermitian, B::SkewHermitian)
throw(DimensionMismatch("A has size $(size(A)) but B has size $(size(B))"))
end
dotprod = zero(dot(first(A), first(B)))
@inbounds for j = 1:n, i = 1:j-1
dotprod += 2 * real(dot(A.data[i, j], B.data[i, j]))
@inbounds for j = 1:n
for i = 1:j-1
dotprod += 2 * dot(A.data[i, j], B.data[i, j])
end
dotprod += dot(A.data[j, j], B.data[j, j])
end
return dotprod
end
Expand Down Expand Up @@ -212,9 +215,9 @@ end
LA.kron(A::SkewHermitian,B::StridedMatrix) = kron(A.data,B)
LA.kron(A::StridedMatrix,B::SkewHermitian) = kron(A,B.data)

@views function LA.schur!(A::SkewHermitian)
@views function LA.schur!(A::SkewHermitian{<:Real})
F=eigen!(A)
return Schur(typeof(F.vectors)(Diagonal(F.values)), F.vectors, F.values)

end
LA.schur(A::SkewHermitian)= LA.schur!(copy(A))
LA.schur(A::SkewHermitian{<:Real})= LA.schur!(copyeigtype(A))
70 changes: 50 additions & 20 deletions src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,34 +200,64 @@ function Base.:-(A::SkewHermTridiagonal)
end
end

function Base.:*(A::SkewHermTridiagonal, B::Number)
function Base.:*(A::SkewHermTridiagonal, B::T) where {T<:Real}
if A.dvim !== nothing
return SkewHermTridiagonal(A.ev*B,A.dvim*B)
else
return SkewHermTridiagonal(A.ev*B)
end
end
function Base.:*(B::Number,A::SkewHermTridiagonal)
function Base.:*(B::T,A::SkewHermTridiagonal) where {T<:Real}
if A.dvim !== nothing
return SkewHermTridiagonal(B*A.ev,B*A.dvim)
else
return SkewHermTridiagonal(B*A.ev)
end
end
function Base.:/(A::SkewHermTridiagonal, B::Number)
function Base.:*(A::SkewHermTridiagonal, B::T) where {T<:Complex}
if A.dvim !== nothing
return LA.Tridiagonal(A.ev*B,A.dvim*B,-A.ev*B)
else
return LA.Tridiagonal(A.ev*B,zeros(eltype(A.ev)),-A.ev*B)
end
end
function Base.:*(B::T,A::SkewHermTridiagonal) where {T<:Complex}
if A.dvim !== nothing
return LA.Tridiagonal(B*A.ev,B*A.dvim,-B*A.ev)
else
return LA.Tridiagonal(B*A.ev,zeros(eltype(A.ev)),-B*A.ev)
end
end

function Base.:/(A::SkewHermTridiagonal, B::T) where {T<:Real}
if A.dvim !== nothing
return SkewHermTridiagonal(A.ev/B,A.dvim/B)
else
return SkewHermTridiagonal(A.ev/B)
end
end
function Base.:\(B::Number,A::SkewHermTridiagonal)
function Base.:/(A::SkewHermTridiagonal, B::T) where {T<:Complex}
if A.dvim !== nothing
return LA.Tridiagonal(A.ev/B,A.dvim/B,-A.ev/B)
else
return LA.Tridiagonal(A.ev/B,zeros(eltype(A.ev)),-A.ev/B)
end
end
function Base.:\(B::T,A::SkewHermTridiagonal) where {T<:Real}
if A.dvim !== nothing
return SkewHermTridiagonal(B\A.ev,B \A.dvim)
else
return SkewHermTridiagonal(B\A.ev)
end
end
function Base.:\(B::T,A::SkewHermTridiagonal) where {T<:Complex}
if A.dvim !== nothing
return LA.Tridiagonal(B\A.ev,B\A.dvim,-B\A.ev)
else
return LA.Tridiagonal(B\A.ev,zeros(eltype(A.ev)),-B\A.ev)
end
end


# ==(A::SkewHermTridiagonal, B::SkewHermTridiagonal) = (A.ev==B.ev)

Expand Down Expand Up @@ -327,53 +357,53 @@ end

#Base.:\(T::SkewHermTridiagonal, B::StridedVecOrMat) = Base.ldlt(T)\B

@views function LA.eigvals!(A::SkewHermTridiagonal{T,V,Vim}, sortby::Union{Function,Nothing}=nothing) where {T<:Real,V,Vim<:Nothing}
@views function LA.eigvals!(A::SkewHermTridiagonal{T,V,Vim}, sortby::Union{Function,Nothing}=nothing) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}
vals = skeweigvals!(A)
!isnothing(sortby) && sort!(vals, by=sortby)
return complex.(0, vals)
end

@views function LA.eigvals!(A::SkewHermTridiagonal{T,V,Vim}, irange::UnitRange) where {T<:Real,V,Vim<:Nothing}
@views function LA.eigvals!(A::SkewHermTridiagonal{T,V,Vim}, irange::UnitRange) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}
vals = skewtrieigvals!(A,irange)
return complex.(0, vals)
end

@views function LA.eigvals!(A::SkewHermTridiagonal{T,V,Vim}, vl::Real,vh::Real) where {T<:Real,V,Vim<:Nothing}
@views function LA.eigvals!(A::SkewHermTridiagonal{T,V,Vim}, vl::Real,vh::Real) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}
vals = skewtrieigvals!(A,-vh,-vl)
return complex.(0, vals)
end

LA.eigvals(A::SkewHermTridiagonal{T,V,Vim}, irange::UnitRange) where {T<:Real,V,Vim<:Nothing} =
LA.eigvals(A::SkewHermTridiagonal{T,V,Vim}, irange::UnitRange) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing} =
LA.eigvals!(copyeigtype(A), irange)
LA.eigvals(A::SkewHermTridiagonal{T,V,Vim}, vl::Real,vh::Real) where {T<:Real,V,Vim<:Nothing}=
LA.eigvals(A::SkewHermTridiagonal{T,V,Vim}, vl::Real,vh::Real) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}=
LA.eigvals!(copyeigtype(A), vl,vh)



@views function skewtrieigvals!(S::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V,Vim<:Nothing}
@views function skewtrieigvals!(S::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}
n = size(S,1)
H = SymTridiagonal(zeros(eltype(S.ev),n),S.ev)
vals = eigvals!(H)
return vals .= .-vals

end

@views function skewtrieigvals!(S::SkewHermTridiagonal{T,V,Vim},irange::UnitRange) where {T<:Real,V,Vim<:Nothing}
@views function skewtrieigvals!(S::SkewHermTridiagonal{T,V,Vim},irange::UnitRange) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}
n = size(S,1)
H = SymTridiagonal(zeros(eltype(S.ev),n),S.ev)
vals = eigvals!(H,irange)
return vals .= .-vals

end

@views function skewtrieigvals!(S::SkewHermTridiagonal{T,V,Vim},vl::Real,vh::Real) where {T<:Real,V,Vim<:Nothing}
@views function skewtrieigvals!(S::SkewHermTridiagonal{T,V,Vim},vl::Real,vh::Real) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}
n = size(S,1)
H = SymTridiagonal(zeros(eltype(S.ev),n),S.ev)
vals = eigvals!(H,vl,vh)
return vals .= .-vals
end

@views function skewtrieigen!(S::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V,Vim<:Nothing}
@views function skewtrieigen!(S::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}

n = size(S,1)
H = SymTridiagonal(zeros(T,n),S.ev)
Expand Down Expand Up @@ -401,7 +431,7 @@ end
end


@views function LA.eigen!(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V,Vim<:Nothing}
@views function LA.eigen!(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}
return skewtrieigen!(A)
end

Expand All @@ -411,20 +441,20 @@ function copyeigtype(A::SkewHermTridiagonal)
return B
end

LA.eigen(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V,Vim<:Nothing}=LA.eigen!(copyeigtype(A))
LA.eigen(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}=LA.eigen!(copyeigtype(A))

LA.eigvecs(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V,Vim<:Nothing}= eigen(A).vectors
LA.eigvecs(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}= eigen(A).vectors


@views function LA.svdvals!(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V,Vim<:Nothing}
@views function LA.svdvals!(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}
n=size(A,1)
vals = skewtrieigvals!(A)
vals .= abs.(vals)
return sort!(vals; rev=true)
end
LA.svdvals(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V,Vim<:Nothing}=svdvals!(A)
LA.svdvals(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}=svdvals!(copyeigtype(A))

@views function LA.svd!(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V,Vim<:Nothing}
@views function LA.svd!(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}
n=size(A,1)
E=eigen!(A)
U=E.vectors
Expand All @@ -445,7 +475,7 @@ LA.svdvals(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V,Vim<:Nothing}=svdva
end


LA.svd(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V,Vim<:Nothing}= svd!(copyeigtype(A))
LA.svd(A::SkewHermTridiagonal{T,V,Vim}) where {T<:Real,V<:AbstractVector{T},Vim<:Nothing}= svd!(copyeigtype(A))


###################
Expand Down
Loading

0 comments on commit 2820ca4

Please sign in to comment.