From f0e34ec2f94357d9bc1c2fa65459646a8c5b3857 Mon Sep 17 00:00:00 2001 From: naezzell Date: Mon, 7 Dec 2020 16:35:39 -0800 Subject: [PATCH] added solve_schrodinger_gpu function to closed_system_solvers This is same function as solve_schrodinger but with GPU support. Future commits will include a better integration, but for now, we wanted to post the working code before worrying about that. --- src/OpenQuantumTools.jl | 1 + src/QSolver/closed_system_solvers.jl | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/OpenQuantumTools.jl b/src/OpenQuantumTools.jl index 4ef8b13..3d01603 100644 --- a/src/OpenQuantumTools.jl +++ b/src/OpenQuantumTools.jl @@ -43,6 +43,7 @@ include("plot_util/bath_plot.jl") export solve_unitary, solve_schrodinger, + solve_schrodinger_gpu, solve_von_neumann, solve_redfield, solve_ame, diff --git a/src/QSolver/closed_system_solvers.jl b/src/QSolver/closed_system_solvers.jl index b81545c..870bcb1 100644 --- a/src/QSolver/closed_system_solvers.jl +++ b/src/QSolver/closed_system_solvers.jl @@ -27,6 +27,22 @@ function solve_schrodinger(A::Annealing, tf::Real; tspan = (0, tf), kwargs...) solve(prob; alg_hints = [:nonstiff], kwargs...) end +function solve_schrodinger_gpu(A::Annealing, tf::Real; tspan = (0, tf), kwargs...) + u0 = cu(build_u0(A.u0, :v)) + p = ODEParams(A.H, float(tf), A.annealing_parameter) + update_func = function (C, u, p, t) + update_cache!(C, p.L, p, p(t)) + end + cache = cu(get_cache(A.H)) + diff_op = DiffEqArrayOperator(cache, update_func = update_func) + jac_cache = cu(similar(cache)) + jac_op = DiffEqArrayOperator(jac_cache, update_func = update_func) + ff = ODEFunction(diff_op, jac_prototype = jac_op) + + prob = ODEProblem{true}(ff, u0, Float32.(tspan), p) + solve(prob; alg_hints = [:nonstiff], kwargs...) +end + """ $(SIGNATURES)