Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add batch dimension to input states for time-dependent solvers #294

Merged
merged 4 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions docs/time_dependent.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,15 @@ def ham_t(t: float) -> Tensor:
t_points = torch.linspace(0, duration, n_steps)
final_state_se = sesolve(ham_t, input_state, t_points, SolverType.DP5_SE).states[-1]

# define jump operator L and solve Lindblad master equation
# define jump operator L
L = IMAT.clone()
for i in range(n_qubits-1):
L = torch.kron(L, XMAT)
final_state_me = mesolve(ham_t, input_state, [L], t_points, SolverType.DP5_ME).states[-1]

# prepare initial density matrix with batch dimension as the last
rho0 = torch.matmul(input_state, input_state.T).unsqueeze(-1)

# solve Lindblad master equation
final_state_me = mesolve(ham_t, rho0, [L], t_points, SolverType.DP5_ME).states[-1]

```
5 changes: 4 additions & 1 deletion pyqtorch/hamiltonians/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def __init__(
cache_length: int = 1,
duration: Tensor | str | float | None = None,
steps: int = 100,
solver=SolverType.DP5_SE,
solver: SolverType = SolverType.DP5_SE,
use_sparse: bool = False,
):
"""Initializes the HamiltonianEvolution.
Depending on the generator argument, set the type and set the right generator getter.
Expand All @@ -167,6 +168,7 @@ def __init__(
self.solver_type = solver
self.steps = steps
self.duration = duration
self.use_sparse = use_sparse

if isinstance(duration, (str, float, Tensor)) or duration is None:
self.duration = duration
Expand Down Expand Up @@ -416,6 +418,7 @@ def Ht(t: torch.Tensor) -> torch.Tensor:
torch.flatten(state, start_dim=0, end_dim=-2),
t_grid,
self.solver_type,
options={"use_sparse": self.use_sparse},
)

# Retrieve the last state of shape (2**n_qubits, batch_size)
Expand Down
16 changes: 11 additions & 5 deletions pyqtorch/time_dependent/integrators/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,10 @@ def init_tstep(
"""

sc = self.options.atol + torch.abs(y0) * self.options.rtol
f0 = f0.to_dense() if self.options.use_sparse else f0
vytautas-a marked this conversation as resolved.
Show resolved Hide resolved
d0, d1 = (
hairer_norm(y0 / sc).max().item(),
hairer_norm(f0.to_dense() / sc).max().item(),
hairer_norm(f0 / sc).max().item(),
)

if d0 < 1e-5 or d1 < 1e-5:
Expand All @@ -138,7 +139,8 @@ def init_tstep(

y1 = y0 + h0 * f0
f1 = fun(t0 + h0, y1)
d2 = hairer_norm((f1 - f0).to_dense() / sc).max().item() / h0
diff = (f1 - f0).to_dense() if self.options.use_sparse else f1 - f0
d2 = hairer_norm(diff / sc).max().item() / h0
if d1 <= 1e-15 and d2 <= 1e-15:
h1 = max(1e-6, h0 * 1e-3)
else:
Expand Down Expand Up @@ -183,13 +185,17 @@ def run(self) -> Tensor:

# initialize the ODE routine
t, y, *args = self.init_forward()
n1, n2 = y.shape

# run the ODE routine
result = []
for tnext in self.tsave:
y, *args = self.integrate(t, tnext, y, *args)
result.append(y.mH if n1 == n2 else y.T)
result.append(y)
t = tnext

return torch.cat(result).reshape((-1, n1, n2))
if len(y.shape) == 2:
res = torch.stack(result)
elif len(y.shape) == 3:
res = torch.stack(result).permute(0, 2, 3, 1)

return res
22 changes: 12 additions & 10 deletions pyqtorch/time_dependent/integrators/krylov.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ def integrate(self, t0: float, t1: float, y: Tensor) -> Tensor:
pass

def run(self) -> Tensor:
t = self.t0

# run the Krylov routine
result = []
y = self.y0
for tnext in self.tsave:
y = self.integrate(t, tnext, y)
result.append(y.T)
t = tnext

return torch.cat(result).unsqueeze(-1)
out = []
for i in range(self.y0.shape[1]):
# run the Krylov routine
result = []
y = self.y0[:, i : i + 1]
t = self.t0
for tnext in self.tsave:
y = self.integrate(t, tnext, y)
result.append(y.T)
t = tnext
out.append(torch.cat(result))
return torch.stack(out, dim=2)
26 changes: 15 additions & 11 deletions pyqtorch/time_dependent/mesolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

def mesolve(
H: Callable[..., Any],
psi0: Tensor,
rho0: Tensor,
L: list[Tensor],
tsave: list | Tensor,
solver: SolverType,
Expand All @@ -22,7 +22,7 @@ def mesolve(

Args:
H (Callable[[float], Tensor]): time-dependent Hamiltonian of the system
psi0 (Tensor): initial state or density matrix of the system
rho0 (Tensor): initial density matrix of the system
L (list[Tensor]): list of jump operators
tsave (Tensor): tensor containing simulation time instants
solver (SolverType): name of the solver to use
Expand All @@ -34,18 +34,22 @@ def mesolve(
"""
options = options or dict()
L = torch.stack(L)
if psi0.size(-2) == 1:
rho0 = psi0.mH @ psi0
elif psi0.size(-1) == 1:
rho0 = psi0 @ psi0.mH
elif psi0.size(-1) == psi0.size(-2):
rho0 = psi0
else:

# check dimensions of initial state
n = H(0.0).shape[0]
if (
(rho0.shape[0] != rho0.shape[1])
or (rho0.shape[0] != n)
or (len(rho0.shape) != 3)
):
raise ValueError(
"Argument `psi0` must be a ket, bra or density matrix, but has shape"
f" {tuple(psi0.shape)}."
f"Argument `rho0` must be a 3D tensor of shape `({n}, {n}, batch_size)`. "
f"Current shape: {tuple(rho0.shape)}."
)

# permute dimensions to allow batch operations
rho0 = rho0.permute(2, 0, 1)

# instantiate appropriate solver
if solver == SolverType.DP5_ME:
opt = AdaptiveSolverOptions(**options)
Expand Down
5 changes: 3 additions & 2 deletions pyqtorch/time_dependent/methods/dp5.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,11 @@ def step(

# compute iterated Runge-Kutta values
k = torch.empty(7, *f.shape, dtype=self.options.ctype)
k[0] = f.to_dense()
k[0] = f.to_dense() if self.options.use_sparse else f
for i in range(1, 7):
dy = torch.tensordot(dt * beta[i - 1, :i], k[:i].clone(), dims=([0], [0]))
k[i] = fun(t + dt * alpha[i - 1].item(), y + dy).to_dense()
a = fun(t + dt * alpha[i - 1].item(), y + dy)
k[i] = a.to_dense() if self.options.use_sparse else a

# compute results
f_new = k[-1]
Expand Down
2 changes: 2 additions & 0 deletions pyqtorch/time_dependent/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class AdaptiveSolverOptions:
max_factor: float = 5.0
ctype: dtype = torch.complex128
rtype: dtype = torch.float64
use_sparse: bool = False


@dataclass
Expand All @@ -25,3 +26,4 @@ class KrylovSolverOptions:
max_krylov: int = 80
exp_tolerance: float = 1e-10
norm_tolerance: float = 1e-10
use_sparse: bool = False
8 changes: 8 additions & 0 deletions pyqtorch/time_dependent/sesolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def sesolve(
Returns:
Result: dataclass containing the simulated states at each time moment
"""
# check dimensions of initial state
n = H(0.0).shape[0]
if (psi0.shape[0] != n) or len(psi0.shape) != 2:
raise ValueError(
f"Argument `psi0` must be a 2D tensor of shape `({n}, batch_size)`. Current shape:"
f" {tuple(psi0.shape)}."
)

options = options or dict()
# instantiate appropriate solver
if solver == SolverType.DP5_SE:
Expand Down
4 changes: 3 additions & 1 deletion pyqtorch/time_dependent/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def ode_fun(self, t: float, psi: Tensor) -> Tensor:
with warnings.catch_warnings():
# filter-out UserWarning about "Sparse CSR tensor support is in beta state"
warnings.filterwarnings("ignore", category=UserWarning)
res = -1j * self.H(t) @ psi.to_sparse()
res = (
-1j * self.H(t) @ (psi.to_sparse() if self.options.use_sparse else psi)
)
return res


Expand Down
17 changes: 13 additions & 4 deletions tests/test_analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,15 @@ def test_hamevo_endianness_cnot() -> None:


@pytest.mark.parametrize("duration", [torch.rand(1), "duration"])
vytautas-a marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("use_sparse", [True, False])
@pytest.mark.parametrize("ode_solver", [SolverType.DP5_SE, SolverType.KRYLOV_SE])
def test_timedependent(
tparam: str,
param_y: float,
duration: float,
batch_size: int,
use_sparse: bool,
n_steps: int,
torch_hamiltonian: Callable,
hamevo_generator: Sequence,
Expand All @@ -306,14 +310,18 @@ def test_timedependent(
ode_solver: SolverType,
) -> None:

psi_start = random_state(2)
psi_start = random_state(2, batch_size)

dur_val = duration if isinstance(duration, torch.Tensor) else torch.rand(1)

# simulate with time-dependent solver
t_points = torch.linspace(0, dur_val[0], n_steps)
psi_solver = pyq.sesolve(
torch_hamiltonian, psi_start.reshape(-1, 1), t_points, ode_solver
torch_hamiltonian,
psi_start.reshape(-1, batch_size),
t_points,
ode_solver,
options={"use_sparse": use_sparse},
).states[-1]

# simulate with HamiltonianEvolution
Expand All @@ -324,14 +332,15 @@ def test_timedependent(
hamiltonian_evolution = pyq.HamiltonianEvolution(
generator=hamevo_generator,
time=tparam,
duration=duration,
duration=dur_val,
steps=n_steps,
solver=ode_solver,
use_sparse=use_sparse,
)
values = {"y": param_y, "duration": dur_val}
psi_hamevo = hamiltonian_evolution(
state=psi_start, values=values, embedding=embedding
).reshape(-1, 1)
).reshape(-1, batch_size)

assert torch.allclose(psi_solver, psi_hamevo, rtol=RTOL, atol=ATOL)

Expand Down
22 changes: 17 additions & 5 deletions tests/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@


@pytest.mark.flaky(max_runs=5)
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("ode_solver", [SolverType.DP5_SE, SolverType.KRYLOV_SE])
def test_sesolve(
duration: float,
batch_size: int,
n_steps: int,
torch_hamiltonian: Callable,
qutip_hamiltonian: Callable,
ode_solver: SolverType,
) -> None:

psi0_qutip = qutip.basis(4, 0)

# simulate with torch-based solver
psi0_torch = torch.tensor(psi0_qutip.full()).to(torch.complex128)
psi0_torch = (
torch.tensor(psi0_qutip.full()).to(torch.complex128).repeat(1, batch_size)
)

t_points = torch.linspace(0, duration, n_steps)
state_torch = sesolve(torch_hamiltonian, psi0_torch, t_points, ode_solver).states[
-1
Expand All @@ -38,15 +42,17 @@ def test_sesolve(
# simulate with qutip solver
t_points = np.linspace(0, duration, n_steps)
result = qutip.sesolve(qutip_hamiltonian, psi0_qutip, t_points)
state_qutip = torch.tensor(result.states[-1].full())
state_qutip = torch.tensor(result.states[-1].full()).repeat(1, batch_size)

assert torch.allclose(state_torch, state_qutip, atol=ATOL)


@pytest.mark.flaky(max_runs=5)
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("ode_solver", [SolverType.DP5_ME])
def test_mesolve(
duration: float,
batch_size: int,
n_steps: int,
torch_hamiltonian: Callable,
qutip_hamiltonian: Callable,
Expand All @@ -58,14 +64,20 @@ def test_mesolve(

# simulate with torch-based solver
psi0_torch = torch.tensor(psi0_qutip.full()).to(torch.complex128)
rho0_torch = (
torch.matmul(psi0_torch, psi0_torch.T).unsqueeze(-1).repeat(1, 1, batch_size)
)

t_points = torch.linspace(0, duration, n_steps)
state_torch = mesolve(
torch_hamiltonian, psi0_torch, jump_op_torch, t_points, ode_solver
torch_hamiltonian, rho0_torch, jump_op_torch, t_points, ode_solver
).states[-1]

# simulate with qutip solver
t_points = np.linspace(0, duration, n_steps)
result = qutip.mesolve(qutip_hamiltonian, psi0_qutip, t_points, jump_op_qutip)
state_qutip = torch.tensor(result.states[-1].full())
state_qutip = (
torch.tensor(result.states[-1].full()).unsqueeze(-1).repeat(1, 1, batch_size)
)

assert torch.allclose(state_torch, state_qutip, atol=ATOL)