Skip to content

Commit

Permalink
Merge pull request #6 from AmitRotem/pst
Browse files Browse the repository at this point in the history
fix covrage
  • Loading branch information
AmitRotem authored Jan 22, 2023
2 parents a128c33 + 9020a11 commit c9ad058
Showing 1 changed file with 12 additions and 16 deletions.
28 changes: 12 additions & 16 deletions test/test_ForwardDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ import ForwardDiff
# system
ba0 = FockBasis(2)
psi = basisstate(ba0, 1)
target0 = basisstate(ba0, 2)
function getHt(p)
op = [create(ba0)+destroy(ba0)]
f(t) = sin(p*t)
H_at_t = LazySum([f(0)], op)
H_at_t = LazySum([f(0.0)], op)
function Ht(t,_)
H_at_t.factors .= (f(t),)
return H_at_t
Expand All @@ -30,27 +29,24 @@ function getHt(p)
end

# cost function
function cost(par; kwargs...)
function cost(par, ψ0; kwargs...)
opti = (;dtmax=exp2(-4), dt=exp2(-4))
Ht = getHt(par)
# this will rebuild the Bra with Dual elements
_, ψT = timeevolution.schroedinger_dynamic((0.0, 0.2), psi' , Ht; opti..., kwargs...)
# this will not rebuild the Bra
_, ψT = timeevolution.schroedinger_dynamic((0.2, 0.4), last(ψT) , Ht; opti..., kwargs...)
# this will not rebuild the Ket
# also tests static schroedinger
_, ψT = timeevolution.schroedinger((0.4, 0.6), last(ψT)', Ht(1.0, ψT); opti..., kwargs...)
# this will not rebuild the Ket
_, ψT = timeevolution.schroedinger_dynamic((0.6, 0.8), last(ψT)last(ψT)', Ht; opti..., kwargs...)
abs2(target0'*last(ψT)*target0)
# this will rebuild the state with Dual elements
_, ψT = timeevolution.schroedinger_dynamic((0.0, 1.0), ψ0, Ht; opti..., kwargs...)
# this will not rebuild the state
_, ψT = timeevolution.schroedinger((1.0, 2.0), last(ψT), Ht(0.5, ψ0); opti..., kwargs...)
(abs2tr)( ψ0.data' * last(ψT).data ) # getting the data so this will work with also Bra states
end

# setup
p0 = rand()
δp = eps()
# test
finite_diff_derivative = ( cost(p0+δp) - cost(p0) ) / δp
Auto_diff_derivative = ForwardDiff.derivative(cost, p0)
@test isapprox(Auto_diff_derivative, finite_diff_derivative; atol=1e-5)
for u0 = (psi, psi', psipsi') # test all methods of `rebuild`
finite_diff_derivative = ( cost(p0+δp, u0) - cost(p0, u0) ) / δp
Auto_diff_derivative = ForwardDiff.derivative(Base.Fix2(cost, u0), p0)
@test isapprox(Auto_diff_derivative, finite_diff_derivative; atol=1e-5)
end

end # testset

0 comments on commit c9ad058

Please sign in to comment.