Skip to content

Commit

Permalink
[Feature] Add batch dimension to input states for time-dependent solv…
Browse files Browse the repository at this point in the history
…ers (#294)

Add batch dimension to input states for time-dependent solvers. Also
introduce flag in config of time-dependent solvers to choose
sparse/dense tensors for calculations.

Closes #292 and #274
  • Loading branch information
vytautas-a authored Nov 15, 2024
1 parent d93e93e commit a11da37
Show file tree
Hide file tree
Showing 11 changed files with 95 additions and 41 deletions.
9 changes: 7 additions & 2 deletions docs/time_dependent.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,15 @@ def ham_t(t: float) -> Tensor:
t_points = torch.linspace(0, duration, n_steps)
final_state_se = sesolve(ham_t, input_state, t_points, SolverType.DP5_SE).states[-1]

# define jump operator L and solve Lindblad master equation
# define jump operator L
L = IMAT.clone()
for i in range(n_qubits-1):
L = torch.kron(L, XMAT)
final_state_me = mesolve(ham_t, input_state, [L], t_points, SolverType.DP5_ME).states[-1]

# prepare initial density matrix with batch dimension as the last
rho0 = torch.matmul(input_state, input_state.T).unsqueeze(-1)

# solve Lindblad master equation
final_state_me = mesolve(ham_t, rho0, [L], t_points, SolverType.DP5_ME).states[-1]

```
5 changes: 4 additions & 1 deletion pyqtorch/hamiltonians/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def __init__(
cache_length: int = 1,
duration: Tensor | str | float | None = None,
steps: int = 100,
solver=SolverType.DP5_SE,
solver: SolverType = SolverType.DP5_SE,
use_sparse: bool = False,
):
"""Initializes the HamiltonianEvolution.
Depending on the generator argument, set the type and set the right generator getter.
Expand All @@ -167,6 +168,7 @@ def __init__(
self.solver_type = solver
self.steps = steps
self.duration = duration
self.use_sparse = use_sparse

if isinstance(duration, (str, float, Tensor)) or duration is None:
self.duration = duration
Expand Down Expand Up @@ -416,6 +418,7 @@ def Ht(t: torch.Tensor) -> torch.Tensor:
torch.flatten(state, start_dim=0, end_dim=-2),
t_grid,
self.solver_type,
options={"use_sparse": self.use_sparse},
)

# Retrieve the last state of shape (2**n_qubits, batch_size)
Expand Down
16 changes: 11 additions & 5 deletions pyqtorch/time_dependent/integrators/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,10 @@ def init_tstep(
"""

sc = self.options.atol + torch.abs(y0) * self.options.rtol
f0 = f0.to_dense() if self.options.use_sparse else f0
d0, d1 = (
hairer_norm(y0 / sc).max().item(),
hairer_norm(f0.to_dense() / sc).max().item(),
hairer_norm(f0 / sc).max().item(),
)

if d0 < 1e-5 or d1 < 1e-5:
Expand All @@ -138,7 +139,8 @@ def init_tstep(

y1 = y0 + h0 * f0
f1 = fun(t0 + h0, y1)
d2 = hairer_norm((f1 - f0).to_dense() / sc).max().item() / h0
diff = (f1 - f0).to_dense() if self.options.use_sparse else f1 - f0
d2 = hairer_norm(diff / sc).max().item() / h0
if d1 <= 1e-15 and d2 <= 1e-15:
h1 = max(1e-6, h0 * 1e-3)
else:
Expand Down Expand Up @@ -183,13 +185,17 @@ def run(self) -> Tensor:

# initialize the ODE routine
t, y, *args = self.init_forward()
n1, n2 = y.shape

# run the ODE routine
result = []
for tnext in self.tsave:
y, *args = self.integrate(t, tnext, y, *args)
result.append(y.mH if n1 == n2 else y.T)
result.append(y)
t = tnext

return torch.cat(result).reshape((-1, n1, n2))
if len(y.shape) == 2:
res = torch.stack(result)
elif len(y.shape) == 3:
res = torch.stack(result).permute(0, 2, 3, 1)

return res
22 changes: 12 additions & 10 deletions pyqtorch/time_dependent/integrators/krylov.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ def integrate(self, t0: float, t1: float, y: Tensor) -> Tensor:
pass

def run(self) -> Tensor:
t = self.t0

# run the Krylov routine
result = []
y = self.y0
for tnext in self.tsave:
y = self.integrate(t, tnext, y)
result.append(y.T)
t = tnext

return torch.cat(result).unsqueeze(-1)
out = []
for i in range(self.y0.shape[1]):
# run the Krylov routine
result = []
y = self.y0[:, i : i + 1]
t = self.t0
for tnext in self.tsave:
y = self.integrate(t, tnext, y)
result.append(y.T)
t = tnext
out.append(torch.cat(result))
return torch.stack(out, dim=2)
26 changes: 15 additions & 11 deletions pyqtorch/time_dependent/mesolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

def mesolve(
H: Callable[..., Any],
psi0: Tensor,
rho0: Tensor,
L: list[Tensor],
tsave: list | Tensor,
solver: SolverType,
Expand All @@ -22,7 +22,7 @@ def mesolve(
Args:
H (Callable[[float], Tensor]): time-dependent Hamiltonian of the system
psi0 (Tensor): initial state or density matrix of the system
rho0 (Tensor): initial density matrix of the system
L (list[Tensor]): list of jump operators
tsave (Tensor): tensor containing simulation time instants
solver (SolverType): name of the solver to use
Expand All @@ -34,18 +34,22 @@ def mesolve(
"""
options = options or dict()
L = torch.stack(L)
if psi0.size(-2) == 1:
rho0 = psi0.mH @ psi0
elif psi0.size(-1) == 1:
rho0 = psi0 @ psi0.mH
elif psi0.size(-1) == psi0.size(-2):
rho0 = psi0
else:

# check dimensions of initial state
n = H(0.0).shape[0]
if (
(rho0.shape[0] != rho0.shape[1])
or (rho0.shape[0] != n)
or (len(rho0.shape) != 3)
):
raise ValueError(
"Argument `psi0` must be a ket, bra or density matrix, but has shape"
f" {tuple(psi0.shape)}."
f"Argument `rho0` must be a 3D tensor of shape `({n}, {n}, batch_size)`. "
f"Current shape: {tuple(rho0.shape)}."
)

# permute dimensions to allow batch operations
rho0 = rho0.permute(2, 0, 1)

# instantiate appropriate solver
if solver == SolverType.DP5_ME:
opt = AdaptiveSolverOptions(**options)
Expand Down
5 changes: 3 additions & 2 deletions pyqtorch/time_dependent/methods/dp5.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,11 @@ def step(

# compute iterated Runge-Kutta values
k = torch.empty(7, *f.shape, dtype=self.options.ctype)
k[0] = f.to_dense()
k[0] = f.to_dense() if self.options.use_sparse else f
for i in range(1, 7):
dy = torch.tensordot(dt * beta[i - 1, :i], k[:i].clone(), dims=([0], [0]))
k[i] = fun(t + dt * alpha[i - 1].item(), y + dy).to_dense()
a = fun(t + dt * alpha[i - 1].item(), y + dy)
k[i] = a.to_dense() if self.options.use_sparse else a

# compute results
f_new = k[-1]
Expand Down
2 changes: 2 additions & 0 deletions pyqtorch/time_dependent/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class AdaptiveSolverOptions:
max_factor: float = 5.0
ctype: dtype = torch.complex128
rtype: dtype = torch.float64
use_sparse: bool = False


@dataclass
Expand All @@ -25,3 +26,4 @@ class KrylovSolverOptions:
max_krylov: int = 80
exp_tolerance: float = 1e-10
norm_tolerance: float = 1e-10
use_sparse: bool = False
8 changes: 8 additions & 0 deletions pyqtorch/time_dependent/sesolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def sesolve(
Returns:
Result: dataclass containing the simulated states at each time moment
"""
# check dimensions of initial state
n = H(0.0).shape[0]
if (psi0.shape[0] != n) or len(psi0.shape) != 2:
raise ValueError(
f"Argument `psi0` must be a 2D tensor of shape `({n}, batch_size)`. Current shape:"
f" {tuple(psi0.shape)}."
)

options = options or dict()
# instantiate appropriate solver
if solver == SolverType.DP5_SE:
Expand Down
4 changes: 3 additions & 1 deletion pyqtorch/time_dependent/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def ode_fun(self, t: float, psi: Tensor) -> Tensor:
with warnings.catch_warnings():
# filter-out UserWarning about "Sparse CSR tensor support is in beta state"
warnings.filterwarnings("ignore", category=UserWarning)
res = -1j * self.H(t) @ psi.to_sparse()
res = (
-1j * self.H(t) @ (psi.to_sparse() if self.options.use_sparse else psi)
)
return res


Expand Down
17 changes: 13 additions & 4 deletions tests/test_analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,15 @@ def test_hamevo_endianness_cnot() -> None:


@pytest.mark.parametrize("duration", [torch.rand(1), "duration"])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("use_sparse", [True, False])
@pytest.mark.parametrize("ode_solver", [SolverType.DP5_SE, SolverType.KRYLOV_SE])
def test_timedependent(
tparam: str,
param_y: float,
duration: float,
batch_size: int,
use_sparse: bool,
n_steps: int,
torch_hamiltonian: Callable,
hamevo_generator: Sequence,
Expand All @@ -306,14 +310,18 @@ def test_timedependent(
ode_solver: SolverType,
) -> None:

psi_start = random_state(2)
psi_start = random_state(2, batch_size)

dur_val = duration if isinstance(duration, torch.Tensor) else torch.rand(1)

# simulate with time-dependent solver
t_points = torch.linspace(0, dur_val[0], n_steps)
psi_solver = pyq.sesolve(
torch_hamiltonian, psi_start.reshape(-1, 1), t_points, ode_solver
torch_hamiltonian,
psi_start.reshape(-1, batch_size),
t_points,
ode_solver,
options={"use_sparse": use_sparse},
).states[-1]

# simulate with HamiltonianEvolution
Expand All @@ -324,14 +332,15 @@ def test_timedependent(
hamiltonian_evolution = pyq.HamiltonianEvolution(
generator=hamevo_generator,
time=tparam,
duration=duration,
duration=dur_val,
steps=n_steps,
solver=ode_solver,
use_sparse=use_sparse,
)
values = {"y": param_y, "duration": dur_val}
psi_hamevo = hamiltonian_evolution(
state=psi_start, values=values, embedding=embedding
).reshape(-1, 1)
).reshape(-1, batch_size)

assert torch.allclose(psi_solver, psi_hamevo, rtol=RTOL, atol=ATOL)

Expand Down
22 changes: 17 additions & 5 deletions tests/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@


@pytest.mark.flaky(max_runs=5)
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("ode_solver", [SolverType.DP5_SE, SolverType.KRYLOV_SE])
def test_sesolve(
duration: float,
batch_size: int,
n_steps: int,
torch_hamiltonian: Callable,
qutip_hamiltonian: Callable,
ode_solver: SolverType,
) -> None:

psi0_qutip = qutip.basis(4, 0)

# simulate with torch-based solver
psi0_torch = torch.tensor(psi0_qutip.full()).to(torch.complex128)
psi0_torch = (
torch.tensor(psi0_qutip.full()).to(torch.complex128).repeat(1, batch_size)
)

t_points = torch.linspace(0, duration, n_steps)
state_torch = sesolve(torch_hamiltonian, psi0_torch, t_points, ode_solver).states[
-1
Expand All @@ -38,15 +42,17 @@ def test_sesolve(
# simulate with qutip solver
t_points = np.linspace(0, duration, n_steps)
result = qutip.sesolve(qutip_hamiltonian, psi0_qutip, t_points)
state_qutip = torch.tensor(result.states[-1].full())
state_qutip = torch.tensor(result.states[-1].full()).repeat(1, batch_size)

assert torch.allclose(state_torch, state_qutip, atol=ATOL)


@pytest.mark.flaky(max_runs=5)
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("ode_solver", [SolverType.DP5_ME])
def test_mesolve(
duration: float,
batch_size: int,
n_steps: int,
torch_hamiltonian: Callable,
qutip_hamiltonian: Callable,
Expand All @@ -58,14 +64,20 @@ def test_mesolve(

# simulate with torch-based solver
psi0_torch = torch.tensor(psi0_qutip.full()).to(torch.complex128)
rho0_torch = (
torch.matmul(psi0_torch, psi0_torch.T).unsqueeze(-1).repeat(1, 1, batch_size)
)

t_points = torch.linspace(0, duration, n_steps)
state_torch = mesolve(
torch_hamiltonian, psi0_torch, jump_op_torch, t_points, ode_solver
torch_hamiltonian, rho0_torch, jump_op_torch, t_points, ode_solver
).states[-1]

# simulate with qutip solver
t_points = np.linspace(0, duration, n_steps)
result = qutip.mesolve(qutip_hamiltonian, psi0_qutip, t_points, jump_op_qutip)
state_qutip = torch.tensor(result.states[-1].full())
state_qutip = (
torch.tensor(result.states[-1].full()).unsqueeze(-1).repeat(1, 1, batch_size)
)

assert torch.allclose(state_torch, state_qutip, atol=ATOL)

0 comments on commit a11da37

Please sign in to comment.