Skip to content

Commit

Permalink
Merge pull request #111 from ytdHuang/opt/handle-eltype
Browse files Browse the repository at this point in the history
Optimize `eltype` handling
  • Loading branch information
ytdHuang authored Oct 15, 2024
2 parents 0022be9 + 169eafb commit f5c1d1e
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 70 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HierarchicalEOM"
uuid = "a62dbcb7-80f5-4d31-9a88-8b19fd92b128"
authors = ["Yi-Te Huang <[email protected]>"]
version = "2.2.3"
version = "2.2.4"

[deps]
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Expand Down Expand Up @@ -35,7 +35,7 @@ LinearSolve = "2.4.2 - 2"
OrdinaryDiffEqCore = "1"
OrdinaryDiffEqLowOrderRK = "1"
Pkg = "1"
QuantumToolbox = "0.13 - 0.18"
QuantumToolbox = "0.15 - 0.18"
Reexport = "1"
SciMLBase = "2"
SciMLOperators = "0.3"
Expand Down
21 changes: 5 additions & 16 deletions ext/HierarchicalEOM_CUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module HierarchicalEOM_CUDAExt

using HierarchicalEOM
import HierarchicalEOM.HeomBase: AbstractHEOMLSMatrix, _Tr, _HandleVectorType, _HandleIdentityType
import HierarchicalEOM.HeomBase: AbstractHEOMLSMatrix, _Tr, _HandleVectorType
import QuantumToolbox: _CType
import CUDA
import CUDA: cu, CuArray
import CUDA.CUSPARSE: CuSparseVector, CuSparseMatrixCSC
Expand Down Expand Up @@ -73,14 +74,9 @@ function _Tr(M::AbstractHEOMLSMatrix{T}) where {T<:CuSparseMatrixCSC}
return CuSparseVector(SparseVector(M.N * D^2, [1 + n * (D + 1) for n in 0:(D-1)], ones(eltype(M), D)))
end

# for changing a `CuArray` back to `ADOs`
_HandleVectorType(V::T, cp::Bool = false) where {T<:CuArray} = Vector{ComplexF64}(V)

# for changing the type of `ADOs` to match the type of HEOMLS matrix
function _HandleVectorType(MatrixType::Type{TM}, V::SparseVector) where {TM<:CuSparseMatrixCSC}
TE = eltype(MatrixType)
return CuArray{TE}(V)
end
# change the type of `ADOs` to match the type of HEOMLS matrix
_HandleVectorType(::AbstractHEOMLSMatrix{<:CuSparseMatrixCSC{T}}, V::SparseVector) where {T<:Number} =
CuArray{_CType(T)}(V)

##### We first remove this part because there are errors when solveing steady states using GPU
# function _HandleSteadyStateMatrix(MatrixType::Type{TM}, M::AbstractHEOMLSMatrix, S::Int) where TM <: CuSparseMatrixCSC
Expand All @@ -95,11 +91,4 @@ end
# return CuSparseMatrixCSC(A)
# end

function _HandleIdentityType(MatrixType::Type{TM}, S::Int) where {TM<:CuSparseMatrixCSC}
colptr = CuArray{Int32}(Int32(1):Int32(S + 1))
rowval = CuArray{Int32}(Int32(1):Int32(S))
nzval = CUDA.ones(ComplexF32, S)
return CuSparseMatrixCSC{ComplexF32,Int32}(colptr, rowval, nzval, (S, S))
end

end
36 changes: 6 additions & 30 deletions src/HeomBase.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
abstract type AbstractHEOMLSMatrix{T} end

_FType(::AbstractHEOMLSMatrix{<:AbstractArray{T}}) where {T<:Number} = _FType(T)
_CType(::AbstractHEOMLSMatrix{<:AbstractArray{T}}) where {T<:Number} = _CType(T)

# equal to : sparse(vec(system_identity_matrix))
function _Tr(dims::SVector, N::Int)
D = prod(dims)
Expand Down Expand Up @@ -29,31 +32,9 @@ HandleMatrixType(M, dims::SVector, MatrixName::String = ""; type::QuantumObjectT
HandleMatrixType(M, MatrixName::String = ""; type::QuantumObjectType = Operator) =
error("HierarchicalEOM doesn't support matrix $(MatrixName) with type : $(typeof(M))")

function _HandleFloatType(ElType::Type{T}, V::StepRangeLen) where {T<:Number}
if real(ElType) == Float32
return StepRangeLen(Float32(V.ref), Float32(V.step), Int32(V.len), Int64(V.offset))
else
return StepRangeLen(Float64(V.ref), Float64(V.step), Int64(V.len), Int64(V.offset))
end
end

function _HandleFloatType(ElType::Type{T}, V::Any) where {T<:Number}
FType = real(ElType)
if eltype(V) == FType
return V
else
convert.(FType, V)
end
end

# for changing a `Vector` back to `ADOs`
_HandleVectorType(V::T, cp::Bool = true) where {T<:Vector} = cp ? Vector{ComplexF64}(V) : V

# for changing the type of `ADOs` to match the type of HEOMLS matrix
function _HandleVectorType(MatrixType::Type{TM}, V::SparseVector) where {TM<:AbstractMatrix}
TE = eltype(MatrixType)
return Vector{TE}(V)
end
# change the type of `ADOs` to match the type of HEOMLS matrix
_HandleVectorType(::AbstractHEOMLSMatrix{<:AbstractSparseMatrix{T}}, V::SparseVector) where {T<:Number} =
Vector{_CType(T)}(V)

function _HandleSteadyStateMatrix(M::AbstractHEOMLSMatrix, S::Int)
ElType = eltype(M)
Expand All @@ -66,11 +47,6 @@ function _HandleSteadyStateMatrix(M::AbstractHEOMLSMatrix, S::Int)
return A
end

function _HandleIdentityType(MatrixType::Type{TM}, S::Int) where {TM<:AbstractMatrix}
ElType = eltype(MatrixType)
return sparse(one(ElType) * I, S, S)
end

function _check_sys_dim_and_ADOs_num(A, B)
if (A.dims != B.dims)
error("Inconsistent system dimension (\"dims\").")
Expand Down
9 changes: 4 additions & 5 deletions src/HierarchicalEOM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@ import Reexport: @reexport
module HeomBase
import Pkg
import LinearAlgebra: BLAS, kron, I
import SparseArrays: sparse, SparseVector, SparseMatrixCSC
import SparseArrays: sparse, SparseVector, SparseMatrixCSC, AbstractSparseMatrix
import StaticArraysCore: SVector
import QuantumToolbox: QuantumObject, QuantumObjectType, Operator, SuperOperator
import QuantumToolbox: _FType, _CType, QuantumObject, QuantumObjectType, Operator, SuperOperator

export _Tr,
AbstractHEOMLSMatrix,
_check_sys_dim_and_ADOs_num,
_check_parity,
HandleMatrixType,
_HandleFloatType,
_HandleVectorType,
_HandleSteadyStateMatrix,
_HandleIdentityType
_HandleSteadyStateMatrix

include("HeomBase.jl")
end
Expand Down Expand Up @@ -89,6 +87,7 @@ module HeomAPI
import SparseArrays: sparse, sparsevec, spzeros, SparseVector, SparseMatrixCSC, AbstractSparseMatrix
import StaticArraysCore: SVector
import QuantumToolbox:
_FType,
QuantumObject,
Operator,
SuperOperator,
Expand Down
9 changes: 4 additions & 5 deletions src/density_of_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ Calculate density of states for the fermionic system in frequency domain.
filename::String = "",
SOLVEROptions...,
)
Size = size(M, 1)

# check M
if M.parity == EVEN
Expand All @@ -56,8 +55,8 @@ Calculate density of states for the fermionic system in frequency domain.
Id_cache = I(M.N)
d_normal = HEOMSuperOp(d_op, ODD, M; Id_cache = Id_cache)
d_dagger = HEOMSuperOp(d_op', ODD, M; Id_cache = Id_cache)
b_m = _HandleVectorType(typeof(M.data), (d_normal * ados).data)
b_p = _HandleVectorType(typeof(M.data), (d_dagger * ados).data)
b_m = _HandleVectorType(M, (d_normal * ados).data)
b_p = _HandleVectorType(M, (d_dagger * ados).data)
_tr_d_normal = _tr * MType(d_normal).data
_tr_d_dagger = _tr * MType(d_dagger).data

Expand All @@ -68,7 +67,7 @@ Calculate density of states for the fermionic system in frequency domain.
end

ElType = eltype(M)
ωList = _HandleFloatType(ElType, ωlist)
ωList = convert(Vector{_FType(M)}, ωlist) # Convert it to support GPUs and avoid type instabilities
Length = length(ωList)
= Vector{Float64}(undef, Length)

Expand All @@ -78,7 +77,7 @@ Calculate density of states for the fermionic system in frequency domain.
end
prog = ProgressBar(Length; enable = verbose)
i = convert(ElType, 1im)
I_total = _HandleIdentityType(typeof(M.data), Size)
I_total = I(size(M, 1))
cache_m = cache_p = nothing
for ω in ωList
= i * ω * I_total
Expand Down
8 changes: 4 additions & 4 deletions src/evolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ function HEOMsolve(
ados = (T_state <: QuantumObject) ? ADOs(ρ0, M.N, M.parity) : ρ0
_check_sys_dim_and_ADOs_num(M, ados)
_check_parity(M, ados)
ρvec = _HandleVectorType(typeof(M.data), ados.data)
ρvec = _HandleVectorType(M, ados.data)

if e_ops isa Nothing
expvals = Array{ComplexF64}(undef, 0, steps + 1)
Expand Down Expand Up @@ -203,9 +203,9 @@ function HEOMsolve(
ados = (T_state <: QuantumObject) ? ADOs(ρ0, M.N, M.parity) : ρ0
_check_sys_dim_and_ADOs_num(M, ados)
_check_parity(M, ados)
u0 = _HandleVectorType(typeof(M.data), ados.data)
u0 = _HandleVectorType(M, ados.data)

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
t_l = convert(Vector{_FType(M)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

# handle e_ops
Id_cache = I(M.N)
Expand Down Expand Up @@ -254,7 +254,7 @@ function HEOMsolve(
flush(stdout)
end
sol = solve(prob, solver)
ADOs_list = map(ρvec -> ADOs(_HandleVectorType(ρvec, false), M.dims, M.N, M.parity), sol.u)
ADOs_list = map(ρvec -> ADOs(Vector{ComplexF64}(ρvec), M.dims, M.N, M.parity), sol.u)

# save ADOs to file
if filename != ""
Expand Down
7 changes: 3 additions & 4 deletions src/power_spectrum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ remember to set the parameters:
filename::String = "",
SOLVEROptions...,
)
Size = size(M, 1)

# Handle ρ
if typeof(ρ) == ADOs # ρ::ADOs
Expand Down Expand Up @@ -109,7 +108,7 @@ remember to set the parameters:
end
_Q_ados = _Q * ados
end
b = _HandleVectorType(typeof(M.data), _Q_ados.data)
b = _HandleVectorType(M, _Q_ados.data)

SAVE::Bool = (filename != "")
if SAVE
Expand All @@ -118,7 +117,7 @@ remember to set the parameters:
end

ElType = eltype(M)
ωList = _HandleFloatType(ElType, ωlist)
ωList = convert(Vector{_FType(M)}, ωlist) # Convert it to support GPUs and avoid type instabilities
Length = length(ωList)
= Vector{Float64}(undef, Length)

Expand All @@ -128,7 +127,7 @@ remember to set the parameters:
end
prog = ProgressBar(Length; enable = verbose)
i = reverse ? convert(ElType, 1im) : i = convert(ElType, -1im)
I_total = _HandleIdentityType(typeof(M.data), Size)
I_total = I(size(M, 1))
cache = nothing
for ω in ωList
= i * ω * I_total
Expand Down
8 changes: 4 additions & 4 deletions src/steadystate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ function steadystate(
print("Solving steady state for ADOs by linear-solve method...")
flush(stdout)
end
cache = init(LinearProblem(A, _HandleVectorType(typeof(M.data), b)), solver, SOLVEROptions...)
cache = init(LinearProblem(A, _HandleVectorType(M, b)), solver, SOLVEROptions...)
sol = solve!(cache)
if verbose
println("[DONE]")
flush(stdout)
end

return ADOs(_HandleVectorType(sol.u, false), M.dims, M.N, M.parity)
return ADOs(Vector{ComplexF64}(sol.u), M.dims, M.N, M.parity)
end

@doc raw"""
Expand Down Expand Up @@ -79,7 +79,7 @@ function steadystate(
ados = (T_state <: QuantumObject) ? ADOs(ρ0, M.N, M.parity) : ρ0
_check_sys_dim_and_ADOs_num(M, ados)
_check_parity(M, ados)
u0 = _HandleVectorType(typeof(M.data), ados.data)
u0 = _HandleVectorType(M, ados.data)

Tspan = (0, tspan)

Expand All @@ -104,7 +104,7 @@ function steadystate(
flush(stdout)
end

return ADOs(_HandleVectorType(sol.u[end], false), M.dims, M.N, M.parity)
return ADOs(Vector{ComplexF64}(sol.u[end]), M.dims, M.N, M.parity)
end

function _ss_condition(integrator, abstol, reltol, min_t)
Expand Down

0 comments on commit f5c1d1e

Please sign in to comment.