From f5367611eec8a8fe16ee772a1565d4bd90403094 Mon Sep 17 00:00:00 2001 From: Stefan Krastanov Date: Tue, 6 Apr 2021 21:58:43 -0400 Subject: [PATCH] Implement OrdinaryDiffEq interface for dense operators. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The following works now: ```julia ℋ = SpinBasis(20//1) const σx = sigmax(ℋ) const iσx = im * σx const σ₋ = sigmam(ℋ) const σ₊ = σ₋' const mhalfσ₊σ₋ = -σ₊*σ₋/2 ↓ = spindown(ℋ) ρ = dm(↓) lind(ρ,p,t) = - iσx * ρ + ρ * iσx + σ₋*ρ*σ₊ + mhalfσ₊σ₋ * ρ + ρ * mhalfσ₊σ₋ t₀, t₁ = (0.0, pi) Δt = 0.1 prob = ODEProblem(lind, ρ, (t₀, t₁)) sol = solve(prob,Tsit5()) ``` Works in-place as well. It is slightly slower than `timeevolution.master`: ```julia function makelind!() tmp = zero(ρ) # this is the global rho function lind!(dρ,ρ,p,t) # TODO this can be much better with a good Tullio kernel mul!(tmp, ρ, σ₊) mul!(dρ, σ₋, ρ) mul!(dρ, ρ, mhalfσ₊σ₋, true, true) mul!(dρ, mhalfσ₊σ₋, ρ, true, true) mul!(dρ, iσx, ρ, -ComplexF64(1), ComplexF64(1)) mul!(dρ, ρ, iσx, true, true) return dρ end end lind! = makelind!() prob! = ODEProblem(lind!, ρ, (t₀, t₁)) julia> @benchmark sol = solve($prob!,DP5(),save_everystep=false) BenchmarkTools.Trial: memory estimate: 408.94 KiB allocs estimate: 213 -------------- minimum time: 126.334 ms (0.00% GC) median time: 127.359 ms (0.00% GC) mean time: 127.876 ms (0.00% GC) maximum time: 138.660 ms (0.00% GC) -------------- samples: 40 evals/sample: 1 julia> @benchmark timeevolution.master([$t₀,$t₁], $ρ, $σx, [$σ₋]) BenchmarkTools.Trial: memory estimate: 497.91 KiB allocs estimate: 210 -------------- minimum time: 97.902 ms (0.00% GC) median time: 98.469 ms (0.00% GC) mean time: 98.655 ms (0.00% GC) maximum time: 104.850 ms (0.00% GC) -------------- samples: 51 evals/sample: 1 ``` --- src/operators_dense.jl | 47 +++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/src/operators_dense.jl b/src/operators_dense.jl index de077558..d68fa073 100644 --- a/src/operators_dense.jl +++ b/src/operators_dense.jl @@ -331,40 +331,45 @@ struct OperatorStyle{BL<:Basis,BR<:Basis} <: DataOperatorStyle{BL,BR} end Broadcast.BroadcastStyle(::Type{<:Operator{BL,BR}}) where {BL<:Basis,BR<:Basis} = OperatorStyle{BL,BR}() Broadcast.BroadcastStyle(::OperatorStyle{B1,B2}, ::OperatorStyle{B3,B4}) where {B1<:Basis,B2<:Basis,B3<:Basis,B4<:Basis} = throw(IncompatibleBases()) +# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`) +Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {Bl<:Basis, Br<:Basis, T<:OperatorStyle{Bl,Br}} = T() + # Out-of-place broadcasting @inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:OperatorStyle{BL,BR},Axes,F,Args<:Tuple} bcf = Broadcast.flatten(bc) bl,br = find_basis(bcf.args) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) - return Operator{BL,BR}(bl, br, copy(bc_)) + T = find_dType(bcf) + data = zeros(T, length(bl), length(br)) + @inbounds @simd for I in eachindex(bcf) + data[I] = bcf[I] + end + return Operator{BL,BR}(bl, br, data) end -find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r) -const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)} -function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:DataOperator}}, axes) - args_ = Tuple(a.data for a=args) - return Broadcast.Broadcasted(f, args_, axes) -end -function Broadcasted_restrict_f(f, args::Tuple{Vararg{<:DataOperator}}, axes) - throw(error("Cannot broadcast function `$f` on type `$(eltype(args))`")) -end +find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r) +find_dType(a::DataOperator, rest) = eltype(a) +Base.getindex(a::DataOperator, idx) = getindex(a.data, idx) +Base.iterate(a::DataOperator) = iterate(a.data) +Base.iterate(a::DataOperator, idx) = iterate(a.data, idx) # In-place broadcasting @inline function Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DataOperatorStyle{BL,BR},Axes,F,Args} axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc)) - # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match - if bc.f === identity && isa(bc.args, Tuple{<:DataOperator{BL,BR}}) # only a single input argument to broadcast! - A = bc.args[1] - if axes(dest) == axes(A) - return copyto!(dest, A) - end + bc′ = Base.Broadcast.preprocess(dest, bc) + dest′ = dest.data + @inbounds @simd for I in eachindex(bc′) + dest′[I] = bc′[I] end - # Get the underlying data fields of operators and broadcast them as arrays - bcf = Broadcast.flatten(bc) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) - copyto!(dest.data, bc_) return dest end @inline Base.copyto!(A::DataOperator{BL,BR},B::DataOperator{BL,BR}) where {BL<:Basis,BR<:Basis} = (copyto!(A.data,B.data); A) @inline Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DataOperatorStyle,Axes,F,Args} = throw(IncompatibleBases()) + +# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl +Base.eltype(::Type{Operator{Bl,Br,A}}) where {Bl,Br,N,A<:AbstractMatrix{N}} = N # ODE init +Base.any(f::Function, ρ::Operator; kwargs...) = any(f, ρ.data; kwargs...) # ODE nan checks +Base.all(f::Function, ρ::Operator; kwargs...) = all(f, ρ.data; kwargs...) +Broadcast.similar(ρ::Operator, t) = typeof(ρ)(ρ.basis_l, ρ.basis_r, copy(ρ.data)) +using RecursiveArrayTools +RecursiveArrayTools.recursivecopy!(dst::Operator{Bl,Br,A},src::Operator{Bl,Br,A}) where {Bl,Br,A} = copy!(dst.data,src.data) # ODE in-place equations