From e08608bab3cb73aa1c494015b155d5b48c11a18f Mon Sep 17 00:00:00 2001 From: Stefan Krastanov Date: Sat, 3 Apr 2021 17:33:36 -0400 Subject: [PATCH 1/2] Implement the OrdinaryDiffEq interface for Kets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The following now works ```julia using QuantumOptics using DifferentialEquations ℋ = SpinBasis(1//2) σx = sigmax(ℋ) ↓ = s = spindown(ℋ) schrod(ψ,p,t) = im * σx * ψ t₀, t₁ = (0.0, pi) Δt = 0.1 prob = ODEProblem(schrod, ↓, (t₀, t₁)) sol = solve(prob,Tsit5()) ``` It works for Bras as well. It works for in-place operations and in some situations it is faster than the standard `timeevolution.schroedinger`. ```julia ℋ = SpinBasis(20//1) ↓ = spindown(ℋ) t₀, t₁ = (0.0, pi) const σx = sigmax(ℋ) const iσx = im * σx schrod!(dψ,ψ,p,t) = mul!(dψ, iσx, ψ) prob! = ODEProblem(schrod!, ↓, (t₀, t₁)) julia> @benchmark sol = solve($prob!,DP5(),save_everystep=false) BenchmarkTools.Trial: memory estimate: 22.67 KiB allocs estimate: 178 -------------- minimum time: 374.463 μs (0.00% GC) median time: 397.327 μs (0.00% GC) mean time: 406.738 μs (0.37% GC) maximum time: 4.386 ms (89.76% GC) -------------- samples: 10000 evals/sample: 1 julia> @benchmark timeevolution.schroedinger([$t₀,$t₁], $↓, $σx) BenchmarkTools.Trial: memory estimate: 23.34 KiB allocs estimate: 161 -------------- minimum time: 748.106 μs (0.00% GC) median time: 774.601 μs (0.00% GC) mean time: 786.933 μs (0.14% GC) maximum time: 4.459 ms (80.46% GC) -------------- samples: 6350 evals/sample: 1 ``` --- Project.toml | 2 + src/operators_dense.jl | 2 +- src/states.jl | 86 ++++++++++++++++++++++-------------------- src/superoperators.jl | 2 +- test/test_states.jl | 15 +++++++- 5 files changed, 62 insertions(+), 45 deletions(-) diff --git a/Project.toml b/Project.toml index 0a1a7b3a..818c0653 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "v0.2.7" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -13,6 +14,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" julia = "1.3" FFTW = "1.2" Adapt = "1, 2" +RecursiveArrayTools = "2.11" [extras] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/operators_dense.jl b/src/operators_dense.jl index 80a6bee2..de077558 100644 --- a/src/operators_dense.jl +++ b/src/operators_dense.jl @@ -340,7 +340,7 @@ Broadcast.BroadcastStyle(::OperatorStyle{B1,B2}, ::OperatorStyle{B3,B4}) where { end find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r) -const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)} +const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)} function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:DataOperator}}, axes) args_ = Tuple(a.data for a=args) return Broadcast.Broadcasted(f, args_, axes) diff --git a/src/states.jl b/src/states.jl index df604f95..7f74561a 100644 --- a/src/states.jl +++ b/src/states.jl @@ -209,50 +209,49 @@ Broadcast.BroadcastStyle(::Type{<:Bra{B}}) where {B<:Basis} = BraStyle{B}() Broadcast.BroadcastStyle(::KetStyle{B1}, ::KetStyle{B2}) where {B1<:Basis,B2<:Basis} = throw(IncompatibleBases()) Broadcast.BroadcastStyle(::BraStyle{B1}, ::BraStyle{B2}) where {B1<:Basis,B2<:Basis} = throw(IncompatibleBases()) +# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`) +Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:KetStyle{B}} = T() +Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:BraStyle{B}} = T() + # Out-of-place broadcasting @inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:KetStyle{B},Axes,F,Args<:Tuple} bcf = Broadcast.flatten(bc) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) b = find_basis(bcf) - return Ket{B}(b, copy(bc_)) + T = find_dType(bcf) + data = zeros(T, length(b)) + @inbounds @simd for I in eachindex(bcf) + data[I] = bcf[I] + end + return Ket{B}(b, data) end @inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:BraStyle{B},Axes,F,Args<:Tuple} bcf = Broadcast.flatten(bc) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) b = find_basis(bcf) - return Bra{B}(b, copy(bc_)) -end -find_basis(bc::Broadcast.Broadcasted) = find_basis(bc.args) -find_basis(args::Tuple) = find_basis(find_basis(args[1]), Base.tail(args)) -find_basis(x) = x -find_basis(a::StateVector, rest) = a.basis -find_basis(::Any, rest) = find_basis(rest) - -const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)} -function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:T}}, axes) where T<:StateVector - args_ = Tuple(a.data for a=args) - return Broadcast.Broadcasted(f, args_, axes) + T = find_dType(bcf) + data = zeros(T, length(b)) + @inbounds @simd for I in eachindex(bcf) + data[I] = bcf[I] + end + return Bra{B}(b, data) end -function Broadcasted_restrict_f(f, args::Tuple{Vararg{<:T}}, axes) where T<:StateVector - throw(error("Cannot broadcast function `$f` on type `$T`")) +for f ∈ [:find_basis,:find_dType] + @eval ($f)(bc::Broadcast.Broadcasted) = ($f)(bc.args) + @eval ($f)(args::Tuple) = ($f)(($f)(args[1]), Base.tail(args)) + @eval ($f)(x) = x + @eval ($f)(::Any, rest) = ($f)(rest) end - +find_basis(a::StateVector, rest) = a.basis +find_dType(a::StateVector, rest) = eltype(a) +Base.getindex(st::StateVector, idx) = getindex(st.data, idx) # In-place broadcasting for Kets @inline function Base.copyto!(dest::Ket{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:KetStyle{B},Axes,F,Args} - axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc)) - # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match - if bc.f === identity && isa(bc.args, Tuple{<:Ket{B}}) # only a single input argument to broadcast! - A = bc.args[1] - if axes(dest) == axes(A) - return copyto!(dest, A) - end + axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) + bc′ = Base.Broadcast.preprocess(dest, bc) + dest′ = dest.data + @inbounds @simd for I in eachindex(bc′) + dest′[I] = bc′[I] end - # Get the underlying data fields of kets and broadcast them as arrays - bcf = Broadcast.flatten(bc) - args_ = Tuple(a.data for a=bcf.args) - bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf)) - copyto!(dest.data, bc_) return dest end @inline Base.copyto!(dest::Ket{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1<:Basis,B2<:Basis,Style<:KetStyle{B2},Axes,F,Args} = @@ -260,21 +259,26 @@ end # In-place broadcasting for Bras @inline function Base.copyto!(dest::Bra{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:BraStyle{B},Axes,F,Args} - axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc)) - # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match - if bc.f === identity && isa(bc.args, Tuple{<:Bra{B}}) # only a single input argument to broadcast! - A = bc.args[1] - if axes(dest) == axes(A) - return copyto!(dest, A) - end + axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) + bc′ = Base.Broadcast.preprocess(dest, bc) + dest′ = dest.data + @inbounds @simd for I in eachindex(bc′) + dest′[I] = bc′[I] end - # Get the underlying data fields of bras and broadcast them as arrays - bcf = Broadcast.flatten(bc) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) - copyto!(dest.data, bc_) return dest end @inline Base.copyto!(dest::Bra{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1<:Basis,B2<:Basis,Style<:BraStyle{B2},Axes,F,Args} = throw(IncompatibleBases()) @inline Base.copyto!(A::T,B::T) where T<:StateVector = (copyto!(A.data,B.data); A) + +# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl +Base.eltype(::Type{Ket{B,A}}) where {B,N,A<:AbstractVector{N}} = N # ODE init +Base.eltype(::Type{Bra{B,A}}) where {B,N,A<:AbstractVector{N}} = N +Base.zero(k::StateVector) = typeof(k)(k.basis, zero(k.data)) # ODE init +Base.any(f::Function, x::StateVector; kwargs...) = any(f, x.data; kwargs...) # ODE nan checks +Base.all(f::Function, x::StateVector; kwargs...) = all(f, x.data; kwargs...) +Broadcast.similar(k::StateVector, t) = typeof(k)(k.basis, copy(k.data)) +using RecursiveArrayTools +RecursiveArrayTools.recursivecopy!(dst::Ket{B,A},src::Ket{B,A}) where {B,A} = copy!(dst.data,src.data) # ODE in-place equations +RecursiveArrayTools.recursivecopy!(dst::Bra{B,A},src::Bra{B,A}) where {B,A} = copy!(dst.data,src.data) \ No newline at end of file diff --git a/src/superoperators.jl b/src/superoperators.jl index 0ef51fac..4518b527 100644 --- a/src/superoperators.jl +++ b/src/superoperators.jl @@ -232,7 +232,7 @@ end # end find_basis(a::SuperOperator, rest) = (a.basis_l, a.basis_r) -const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)} +const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)} function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:SuperOperator}}, axes) args_ = Tuple(a.data for a=args) return Broadcast.Broadcasted(f, args_, axes) diff --git a/test/test_states.jl b/test/test_states.jl index 05334d9b..874b4149 100644 --- a/test/test_states.jl +++ b/test/test_states.jl @@ -161,7 +161,18 @@ psi_ .+= psi123 bra_ = copy(bra123) bra_ .= 3*bra123 @test bra_ == 3*dagger(psi123) -@test_throws ErrorException cos.(psi_) -@test_throws ErrorException cos.(bra_) +@test bra_ .* 2 == bra_ .+ bra_ +@test bra_ * 2 == bra_ .+ bra_ +z = zero(bra_) +z .= bra_ .* 2 +@test_broken all(z .== bra_ .+ bra_) +@test z == bra_ .+ bra_ +ket_ = bra_' +@test ket_ .* 2 == ket_ .+ ket_ +@test ket_ * 2 == ket_ .+ ket_ +z = zero(ket_) +z .= ket_ .* 2 +@test_broken all(z .== ket_.+ ket_) +@test z == ket_ .+ ket_ end # testset From 4ca650085767559a7a26189e9f56657b5303523b Mon Sep 17 00:00:00 2001 From: Stefan Krastanov Date: Tue, 6 Apr 2021 21:58:43 -0400 Subject: [PATCH 2/2] Implement OrdinaryDiffEq interface for dense operators. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The following works now: ```julia ℋ = SpinBasis(20//1) const σx = sigmax(ℋ) const iσx = im * σx const σ₋ = sigmam(ℋ) const σ₊ = σ₋' const mhalfσ₊σ₋ = -σ₊*σ₋/2 ↓ = spindown(ℋ) ρ = dm(↓) lind(ρ,p,t) = - iσx * ρ + ρ * iσx + σ₋*ρ*σ₊ + mhalfσ₊σ₋ * ρ + ρ * mhalfσ₊σ₋ t₀, t₁ = (0.0, pi) Δt = 0.1 prob = ODEProblem(lind, ρ, (t₀, t₁)) sol = solve(prob,Tsit5()) ``` Works in-place as well. It is slightly slower than `timeevolution.master`: ```julia function makelind!() tmp = zero(ρ) # this is the global rho function lind!(dρ,ρ,p,t) # TODO this can be much better with a good Tullio kernel mul!(tmp, ρ, σ₊) mul!(dρ, σ₋, ρ) mul!(dρ, ρ, mhalfσ₊σ₋, true, true) mul!(dρ, mhalfσ₊σ₋, ρ, true, true) mul!(dρ, iσx, ρ, -ComplexF64(1), ComplexF64(1)) mul!(dρ, ρ, iσx, true, true) return dρ end end lind! = makelind!() prob! = ODEProblem(lind!, ρ, (t₀, t₁)) julia> @benchmark sol = solve($prob!,DP5(),save_everystep=false) BenchmarkTools.Trial: memory estimate: 408.94 KiB allocs estimate: 213 -------------- minimum time: 126.334 ms (0.00% GC) median time: 127.359 ms (0.00% GC) mean time: 127.876 ms (0.00% GC) maximum time: 138.660 ms (0.00% GC) -------------- samples: 40 evals/sample: 1 julia> @benchmark timeevolution.master([$t₀,$t₁], $ρ, $σx, [$σ₋]) BenchmarkTools.Trial: memory estimate: 497.91 KiB allocs estimate: 210 -------------- minimum time: 97.902 ms (0.00% GC) median time: 98.469 ms (0.00% GC) mean time: 98.655 ms (0.00% GC) maximum time: 104.850 ms (0.00% GC) -------------- samples: 51 evals/sample: 1 ``` --- src/operators_dense.jl | 47 +++++++++++++++++++---------------- test/test_abstractdata.jl | 1 - test/test_operators_dense.jl | 6 ++++- test/test_operators_sparse.jl | 1 - 4 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/operators_dense.jl b/src/operators_dense.jl index de077558..d68fa073 100644 --- a/src/operators_dense.jl +++ b/src/operators_dense.jl @@ -331,40 +331,45 @@ struct OperatorStyle{BL<:Basis,BR<:Basis} <: DataOperatorStyle{BL,BR} end Broadcast.BroadcastStyle(::Type{<:Operator{BL,BR}}) where {BL<:Basis,BR<:Basis} = OperatorStyle{BL,BR}() Broadcast.BroadcastStyle(::OperatorStyle{B1,B2}, ::OperatorStyle{B3,B4}) where {B1<:Basis,B2<:Basis,B3<:Basis,B4<:Basis} = throw(IncompatibleBases()) +# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`) +Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {Bl<:Basis, Br<:Basis, T<:OperatorStyle{Bl,Br}} = T() + # Out-of-place broadcasting @inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:OperatorStyle{BL,BR},Axes,F,Args<:Tuple} bcf = Broadcast.flatten(bc) bl,br = find_basis(bcf.args) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) - return Operator{BL,BR}(bl, br, copy(bc_)) + T = find_dType(bcf) + data = zeros(T, length(bl), length(br)) + @inbounds @simd for I in eachindex(bcf) + data[I] = bcf[I] + end + return Operator{BL,BR}(bl, br, data) end -find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r) -const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)} -function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:DataOperator}}, axes) - args_ = Tuple(a.data for a=args) - return Broadcast.Broadcasted(f, args_, axes) -end -function Broadcasted_restrict_f(f, args::Tuple{Vararg{<:DataOperator}}, axes) - throw(error("Cannot broadcast function `$f` on type `$(eltype(args))`")) -end +find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r) +find_dType(a::DataOperator, rest) = eltype(a) +Base.getindex(a::DataOperator, idx) = getindex(a.data, idx) +Base.iterate(a::DataOperator) = iterate(a.data) +Base.iterate(a::DataOperator, idx) = iterate(a.data, idx) # In-place broadcasting @inline function Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DataOperatorStyle{BL,BR},Axes,F,Args} axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc)) - # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match - if bc.f === identity && isa(bc.args, Tuple{<:DataOperator{BL,BR}}) # only a single input argument to broadcast! - A = bc.args[1] - if axes(dest) == axes(A) - return copyto!(dest, A) - end + bc′ = Base.Broadcast.preprocess(dest, bc) + dest′ = dest.data + @inbounds @simd for I in eachindex(bc′) + dest′[I] = bc′[I] end - # Get the underlying data fields of operators and broadcast them as arrays - bcf = Broadcast.flatten(bc) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) - copyto!(dest.data, bc_) return dest end @inline Base.copyto!(A::DataOperator{BL,BR},B::DataOperator{BL,BR}) where {BL<:Basis,BR<:Basis} = (copyto!(A.data,B.data); A) @inline Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DataOperatorStyle,Axes,F,Args} = throw(IncompatibleBases()) + +# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl +Base.eltype(::Type{Operator{Bl,Br,A}}) where {Bl,Br,N,A<:AbstractMatrix{N}} = N # ODE init +Base.any(f::Function, ρ::Operator; kwargs...) = any(f, ρ.data; kwargs...) # ODE nan checks +Base.all(f::Function, ρ::Operator; kwargs...) = all(f, ρ.data; kwargs...) +Broadcast.similar(ρ::Operator, t) = typeof(ρ)(ρ.basis_l, ρ.basis_r, copy(ρ.data)) +using RecursiveArrayTools +RecursiveArrayTools.recursivecopy!(dst::Operator{Bl,Br,A},src::Operator{Bl,Br,A}) where {Bl,Br,A} = copy!(dst.data,src.data) # ODE in-place equations diff --git a/test/test_abstractdata.jl b/test/test_abstractdata.jl index 5aff3e70..3dd6610f 100644 --- a/test/test_abstractdata.jl +++ b/test/test_abstractdata.jl @@ -340,7 +340,6 @@ op1 .= op1_ .+ 3 * op1_ bf = FockBasis(3) op3 = randtestoperator(bf) @test_throws QuantumOpticsBase.IncompatibleBases op1 .+ op3 -@test_throws ErrorException cos.(op1) #################### # Test lazy tensor # diff --git a/test/test_operators_dense.jl b/test/test_operators_dense.jl index f10e4fa8..ce3aeffe 100644 --- a/test/test_operators_dense.jl +++ b/test/test_operators_dense.jl @@ -360,6 +360,10 @@ op1 .= op1_ .+ 3 * op1_ bf = FockBasis(3) op3 = randoperator(bf) @test_throws QuantumOpticsBase.IncompatibleBases op1 .+ op3 -@test_throws ErrorException cos.(op1) +@test op3 * 2 == op3 .+ op3 +z = zero(op3) +z .= op3 .* 3 +@test z == op3 .* 2 .+ op3 +@test_broken all(z .== op3 .* 2 .+ op3) end # testset diff --git a/test/test_operators_sparse.jl b/test/test_operators_sparse.jl index 05585ae6..b47c2730 100644 --- a/test/test_operators_sparse.jl +++ b/test/test_operators_sparse.jl @@ -386,6 +386,5 @@ op3 = sprandop(FockBasis(1),FockBasis(2)) op_ = copy(op1) op_ .+= op1 @test op_ == 2*op1 -@test_throws ErrorException cos.(op_) end # testset