diff --git a/test/test_ForwardDiff.jl b/test/test_ForwardDiff.jl index a34f84d7..01d8135c 100644 --- a/test/test_ForwardDiff.jl +++ b/test/test_ForwardDiff.jl @@ -17,11 +17,10 @@ import ForwardDiff # system ba0 = FockBasis(2) psi = basisstate(ba0, 1) -target0 = basisstate(ba0, 2) function getHt(p) op = [create(ba0)+destroy(ba0)] f(t) = sin(p*t) - H_at_t = LazySum([f(0)], op) + H_at_t = LazySum([f(0.0)], op) function Ht(t,_) H_at_t.factors .= (f(t),) return H_at_t @@ -30,27 +29,24 @@ function getHt(p) end # cost function -function cost(par; kwargs...) +function cost(par, ψ0; kwargs...) opti = (;dtmax=exp2(-4), dt=exp2(-4)) Ht = getHt(par) - # this will rebuild the Bra with Dual elements - _, ψT = timeevolution.schroedinger_dynamic((0.0, 0.2), psi' , Ht; opti..., kwargs...) - # this will not rebuild the Bra - _, ψT = timeevolution.schroedinger_dynamic((0.2, 0.4), last(ψT) , Ht; opti..., kwargs...) - # this will not rebuild the Ket - # also tests static schroedinger - _, ψT = timeevolution.schroedinger((0.4, 0.6), last(ψT)', Ht(1.0, ψT); opti..., kwargs...) - # this will not rebuild the Ket - _, ψT = timeevolution.schroedinger_dynamic((0.6, 0.8), last(ψT)⊗last(ψT)', Ht; opti..., kwargs...) - abs2(target0'*last(ψT)*target0) + # this will rebuild the state with Dual elements + _, ψT = timeevolution.schroedinger_dynamic((0.0, 1.0), ψ0, Ht; opti..., kwargs...) + # this will not rebuild the state + _, ψT = timeevolution.schroedinger((1.0, 2.0), last(ψT), Ht(0.5, ψ0); opti..., kwargs...) + (abs2∘tr)( ψ0.data' * last(ψT).data ) # getting the data so this will work with also Bra states end # setup p0 = rand() δp = √eps() # test -finite_diff_derivative = ( cost(p0+δp) - cost(p0) ) / δp -Auto_diff_derivative = ForwardDiff.derivative(cost, p0) -@test isapprox(Auto_diff_derivative, finite_diff_derivative; atol=1e-5) +for u0 = (psi, psi', psi⊗psi') # test all methods of `rebuild` + finite_diff_derivative = ( cost(p0+δp, u0) - cost(p0, u0) ) / δp + Auto_diff_derivative = ForwardDiff.derivative(Base.Fix2(cost, u0), p0) + @test isapprox(Auto_diff_derivative, finite_diff_derivative; atol=1e-5) +end end # testset