Skip to content

Commit

Permalink
Implement OrdinaryDiffEq interface for dense operators.
Browse files Browse the repository at this point in the history
The following works now:

```julia
ℋ = SpinBasis(20//1)

const σx = sigmax(ℋ)
const iσx = im * σx
const σ₋ = sigmam(ℋ)
const σ₊ = σ₋'
const mhalfσ₊σ₋ = -σ₊*σ₋/2

↓ = spindown(ℋ)
ρ = dm(↓)

lind(ρ,p,t) = - iσx * ρ + ρ * iσx + σ₋*ρ*σ₊ + mhalfσ₊σ₋ * ρ + ρ * mhalfσ₊σ₋

t₀, t₁ = (0.0, pi)
Δt = 0.1

prob = ODEProblem(lind, ρ, (t₀, t₁))
sol = solve(prob,Tsit5())
```

Works in-place as well.

It is slightly slower than `timeevolution.master`:

```julia
function makelind!()
    tmp = zero(ρ) # this is the global rho
    function lind!(dρ,ρ,p,t) # TODO this can be much better with a good Tullio kernel
        mul!(tmp, ρ, σ₊)
        mul!(dρ, σ₋, ρ)
        mul!(dρ,    ρ, mhalfσ₊σ₋, true, true)
        mul!(dρ, mhalfσ₊σ₋,    ρ, true, true)
        mul!(dρ,  iσx,    ρ, -ComplexF64(1),   ComplexF64(1))
        mul!(dρ,    ρ,  iσx,  true,   true)
        return dρ
    end
end
lind! = makelind!()
prob! = ODEProblem(lind!, ρ, (t₀, t₁))

julia> @benchmark sol = solve($prob!,DP5(),save_everystep=false)
BenchmarkTools.Trial:
  memory estimate:  408.94 KiB
  allocs estimate:  213
  --------------
  minimum time:     126.334 ms (0.00% GC)
  median time:      127.359 ms (0.00% GC)
  mean time:        127.876 ms (0.00% GC)
  maximum time:     138.660 ms (0.00% GC)
  --------------
  samples:          40
  evals/sample:     1

julia> @benchmark timeevolution.master([$t₀,$t₁], $ρ, $σx, [$σ₋])
BenchmarkTools.Trial:
  memory estimate:  497.91 KiB
  allocs estimate:  210
  --------------
  minimum time:     97.902 ms (0.00% GC)
  median time:      98.469 ms (0.00% GC)
  mean time:        98.655 ms (0.00% GC)
  maximum time:     104.850 ms (0.00% GC)
  --------------
  samples:          51
  evals/sample:     1
```
  • Loading branch information
Krastanov committed Apr 7, 2021
1 parent e08608b commit f536761
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions src/operators_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,40 +331,45 @@ struct OperatorStyle{BL<:Basis,BR<:Basis} <: DataOperatorStyle{BL,BR} end
Broadcast.BroadcastStyle(::Type{<:Operator{BL,BR}}) where {BL<:Basis,BR<:Basis} = OperatorStyle{BL,BR}()
Broadcast.BroadcastStyle(::OperatorStyle{B1,B2}, ::OperatorStyle{B3,B4}) where {B1<:Basis,B2<:Basis,B3<:Basis,B4<:Basis} = throw(IncompatibleBases())

# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`)
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {Bl<:Basis, Br<:Basis, T<:OperatorStyle{Bl,Br}} = T()

# Out-of-place broadcasting
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:OperatorStyle{BL,BR},Axes,F,Args<:Tuple}
bcf = Broadcast.flatten(bc)
bl,br = find_basis(bcf.args)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
return Operator{BL,BR}(bl, br, copy(bc_))
T = find_dType(bcf)
data = zeros(T, length(bl), length(br))
@inbounds @simd for I in eachindex(bcf)
data[I] = bcf[I]
end
return Operator{BL,BR}(bl, br, data)
end
find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r)

const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)}
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:DataOperator}}, axes)
args_ = Tuple(a.data for a=args)
return Broadcast.Broadcasted(f, args_, axes)
end
function Broadcasted_restrict_f(f, args::Tuple{Vararg{<:DataOperator}}, axes)
throw(error("Cannot broadcast function `$f` on type `$(eltype(args))`"))
end
find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r)
find_dType(a::DataOperator, rest) = eltype(a)
Base.getindex(a::DataOperator, idx) = getindex(a.data, idx)
Base.iterate(a::DataOperator) = iterate(a.data)
Base.iterate(a::DataOperator, idx) = iterate(a.data, idx)

# In-place broadcasting
@inline function Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DataOperatorStyle{BL,BR},Axes,F,Args}
axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
if bc.f === identity && isa(bc.args, Tuple{<:DataOperator{BL,BR}}) # only a single input argument to broadcast!
A = bc.args[1]
if axes(dest) == axes(A)
return copyto!(dest, A)
end
bc′ = Base.Broadcast.preprocess(dest, bc)
dest′ = dest.data
@inbounds @simd for I in eachindex(bc′)
dest′[I] = bc′[I]
end
# Get the underlying data fields of operators and broadcast them as arrays
bcf = Broadcast.flatten(bc)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
copyto!(dest.data, bc_)
return dest
end
@inline Base.copyto!(A::DataOperator{BL,BR},B::DataOperator{BL,BR}) where {BL<:Basis,BR<:Basis} = (copyto!(A.data,B.data); A)
@inline Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DataOperatorStyle,Axes,F,Args} =
throw(IncompatibleBases())

# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl
Base.eltype(::Type{Operator{Bl,Br,A}}) where {Bl,Br,N,A<:AbstractMatrix{N}} = N # ODE init
Base.any(f::Function, ρ::Operator; kwargs...) = any(f, ρ.data; kwargs...) # ODE nan checks
Base.all(f::Function, ρ::Operator; kwargs...) = all(f, ρ.data; kwargs...)
Broadcast.similar::Operator, t) = typeof(ρ)(ρ.basis_l, ρ.basis_r, copy.data))
using RecursiveArrayTools
RecursiveArrayTools.recursivecopy!(dst::Operator{Bl,Br,A},src::Operator{Bl,Br,A}) where {Bl,Br,A} = copy!(dst.data,src.data) # ODE in-place equations

0 comments on commit f536761

Please sign in to comment.