Skip to content

Commit

Permalink
added solve_schrodinger_gpu function to closed_system_solvers
Browse files Browse the repository at this point in the history
This is same function as solve_schrodinger but with GPU support. Future commits
will include a better integration, but for now, we wanted to post the working
code before worrying about that.
  • Loading branch information
naezzell committed Dec 8, 2020
1 parent 169cb80 commit f0e34ec
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/OpenQuantumTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ include("plot_util/bath_plot.jl")

export solve_unitary,
solve_schrodinger,
solve_schrodinger_gpu,
solve_von_neumann,
solve_redfield,
solve_ame,
Expand Down
16 changes: 16 additions & 0 deletions src/QSolver/closed_system_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ function solve_schrodinger(A::Annealing, tf::Real; tspan = (0, tf), kwargs...)
solve(prob; alg_hints = [:nonstiff], kwargs...)
end

function solve_schrodinger_gpu(A::Annealing, tf::Real; tspan = (0, tf), kwargs...)
u0 = cu(build_u0(A.u0, :v))
p = ODEParams(A.H, float(tf), A.annealing_parameter)
update_func = function (C, u, p, t)
update_cache!(C, p.L, p, p(t))
end
cache = cu(get_cache(A.H))
diff_op = DiffEqArrayOperator(cache, update_func = update_func)
jac_cache = cu(similar(cache))
jac_op = DiffEqArrayOperator(jac_cache, update_func = update_func)
ff = ODEFunction(diff_op, jac_prototype = jac_op)

prob = ODEProblem{true}(ff, u0, Float32.(tspan), p)
solve(prob; alg_hints = [:nonstiff], kwargs...)
end

"""
$(SIGNATURES)
Expand Down

0 comments on commit f0e34ec

Please sign in to comment.