From a11da3716f7200269c3115d154046ed7ed873d10 Mon Sep 17 00:00:00 2001
From: Vytautas Abramavicius <145791635+vytautas-a@users.noreply.github.com>
Date: Fri, 15 Nov 2024 15:32:07 +0200
Subject: [PATCH] [Feature] Add batch dimension to input states for
 time-dependent solvers (#294)

Add batch dimension to input states for time-dependent solvers. Also
introduce flag in config of time-dependent solvers to choose sparse/dense
tensors for calculations.

Closes #292 and #274
---
 docs/time_dependent.md                        |  9 +++++--
 pyqtorch/hamiltonians/evolution.py            |  5 +++-
 .../time_dependent/integrators/adaptive.py    | 16 ++++++++----
 pyqtorch/time_dependent/integrators/krylov.py | 22 +++++++++-------
 pyqtorch/time_dependent/mesolve.py            | 26 +++++++++++--------
 pyqtorch/time_dependent/methods/dp5.py        |  5 ++--
 pyqtorch/time_dependent/options.py            |  2 ++
 pyqtorch/time_dependent/sesolve.py            |  8 ++++++
 pyqtorch/time_dependent/solvers.py            |  4 ++-
 tests/test_analog.py                          | 17 +++++++++---
 tests/test_solvers.py                         | 22 ++++++++++++----
 11 files changed, 95 insertions(+), 41 deletions(-)

diff --git a/docs/time_dependent.md b/docs/time_dependent.md
index 6fdcbc79..59d9d2af 100644
--- a/docs/time_dependent.md
+++ b/docs/time_dependent.md
@@ -30,10 +30,15 @@ def ham_t(t: float) -> Tensor:
 t_points = torch.linspace(0, duration, n_steps)
 final_state_se = sesolve(ham_t, input_state, t_points, SolverType.DP5_SE).states[-1]
 
-# define jump operator L and solve Lindblad master equation
+# define jump operator L
 L = IMAT.clone()
 for i in range(n_qubits-1):
     L = torch.kron(L, XMAT)
-final_state_me = mesolve(ham_t, input_state, [L], t_points, SolverType.DP5_ME).states[-1]
+
+# prepare initial density matrix with batch dimension as the last
+rho0 = torch.matmul(input_state, input_state.T).unsqueeze(-1)
+
+# solve Lindblad master equation
+final_state_me = mesolve(ham_t, rho0, [L], t_points, SolverType.DP5_ME).states[-1]
 
 ```
diff --git a/pyqtorch/hamiltonians/evolution.py b/pyqtorch/hamiltonians/evolution.py
index d819a128..9d46c3af 100644
--- a/pyqtorch/hamiltonians/evolution.py
+++ b/pyqtorch/hamiltonians/evolution.py
@@ -150,7 +150,8 @@ def __init__(
         cache_length: int = 1,
         duration: Tensor | str | float | None = None,
         steps: int = 100,
-        solver=SolverType.DP5_SE,
+        solver: SolverType = SolverType.DP5_SE,
+        use_sparse: bool = False,
     ):
         """Initializes the HamiltonianEvolution.
         Depending on the generator argument, set the type and set the right generator getter.
@@ -167,6 +168,7 @@ def __init__(
         self.solver_type = solver
         self.steps = steps
         self.duration = duration
+        self.use_sparse = use_sparse
 
         if isinstance(duration, (str, float, Tensor)) or duration is None:
             self.duration = duration
@@ -416,6 +418,7 @@ def Ht(t: torch.Tensor) -> torch.Tensor:
                 torch.flatten(state, start_dim=0, end_dim=-2),
                 t_grid,
                 self.solver_type,
+                options={"use_sparse": self.use_sparse},
             )
 
             # Retrieve the last state of shape (2**n_qubits, batch_size)
diff --git a/pyqtorch/time_dependent/integrators/adaptive.py b/pyqtorch/time_dependent/integrators/adaptive.py
index fd5304a9..60e4be88 100644
--- a/pyqtorch/time_dependent/integrators/adaptive.py
+++ b/pyqtorch/time_dependent/integrators/adaptive.py
@@ -126,9 +126,10 @@ def init_tstep(
 
         """
         sc = self.options.atol + torch.abs(y0) * self.options.rtol
+        f0 = f0.to_dense() if self.options.use_sparse else f0
         d0, d1 = (
             hairer_norm(y0 / sc).max().item(),
-            hairer_norm(f0.to_dense() / sc).max().item(),
+            hairer_norm(f0 / sc).max().item(),
         )
 
         if d0 < 1e-5 or d1 < 1e-5:
@@ -138,7 +139,8 @@ def init_tstep(
 
         y1 = y0 + h0 * f0
         f1 = fun(t0 + h0, y1)
-        d2 = hairer_norm((f1 - f0).to_dense() / sc).max().item() / h0
+        diff = (f1 - f0).to_dense() if self.options.use_sparse else f1 - f0
+        d2 = hairer_norm(diff / sc).max().item() / h0
         if d1 <= 1e-15 and d2 <= 1e-15:
             h1 = max(1e-6, h0 * 1e-3)
         else:
@@ -183,13 +185,17 @@ def run(self) -> Tensor:
 
         # initialize the ODE routine
         t, y, *args = self.init_forward()
-        n1, n2 = y.shape
 
         # run the ODE routine
         result = []
         for tnext in self.tsave:
             y, *args = self.integrate(t, tnext, y, *args)
-            result.append(y.mH if n1 == n2 else y.T)
+            result.append(y)
             t = tnext
 
-        return torch.cat(result).reshape((-1, n1, n2))
+        if len(y.shape) == 2:
+            res = torch.stack(result)
+        elif len(y.shape) == 3:
+            res = torch.stack(result).permute(0, 2, 3, 1)
+
+        return res
diff --git a/pyqtorch/time_dependent/integrators/krylov.py b/pyqtorch/time_dependent/integrators/krylov.py
index 99be0587..7c290a81 100644
--- a/pyqtorch/time_dependent/integrators/krylov.py
+++ b/pyqtorch/time_dependent/integrators/krylov.py
@@ -30,14 +30,16 @@ def integrate(self, t0: float, t1: float, y: Tensor) -> Tensor:
         pass
 
     def run(self) -> Tensor:
-        t = self.t0
-        # run the Krylov routine
-        result = []
-        y = self.y0
-        for tnext in self.tsave:
-            y = self.integrate(t, tnext, y)
-            result.append(y.T)
-            t = tnext
-
-        return torch.cat(result).unsqueeze(-1)
+        out = []
+        for i in range(self.y0.shape[1]):
+            # run the Krylov routine
+            result = []
+            y = self.y0[:, i : i + 1]
+            t = self.t0
+            for tnext in self.tsave:
+                y = self.integrate(t, tnext, y)
+                result.append(y.T)
+                t = tnext
+            out.append(torch.cat(result))
+        return torch.stack(out, dim=2)
 
diff --git a/pyqtorch/time_dependent/mesolve.py b/pyqtorch/time_dependent/mesolve.py
index e429781f..3bede864 100644
--- a/pyqtorch/time_dependent/mesolve.py
+++ b/pyqtorch/time_dependent/mesolve.py
@@ -12,7 +12,7 @@
 
 def mesolve(
     H: Callable[..., Any],
-    psi0: Tensor,
+    rho0: Tensor,
     L: list[Tensor],
     tsave: list | Tensor,
     solver: SolverType,
@@ -22,7 +22,7 @@ def mesolve(
 
     Args:
         H (Callable[[float], Tensor]): time-dependent Hamiltonian of the system
-        psi0 (Tensor): initial state or density matrix of the system
+        rho0 (Tensor): initial density matrix of the system
         L (list[Tensor]): list of jump operators
         tsave (Tensor): tensor containing simulation time instants
         solver (SolverType): name of the solver to use
@@ -34,18 +34,22 @@ def mesolve(
     """
    options = options or dict()
     L = torch.stack(L)
-    if psi0.size(-2) == 1:
-        rho0 = psi0.mH @ psi0
-    elif psi0.size(-1) == 1:
-        rho0 = psi0 @ psi0.mH
-    elif psi0.size(-1) == psi0.size(-2):
-        rho0 = psi0
-    else:
+
+    # check dimensions of initial state
+    n = H(0.0).shape[0]
+    if (
+        (rho0.shape[0] != rho0.shape[1])
+        or (rho0.shape[0] != n)
+        or (len(rho0.shape) != 3)
+    ):
         raise ValueError(
-            "Argument `psi0` must be a ket, bra or density matrix, but has shape"
-            f" {tuple(psi0.shape)}."
+            f"Argument `rho0` must be a 3D tensor of shape `({n}, {n}, batch_size)`. "
+            f"Current shape: {tuple(rho0.shape)}."
         )
 
+    # permute dimensions to allow batch operations
+    rho0 = rho0.permute(2, 0, 1)
+
     # instantiate appropriate solver
     if solver == SolverType.DP5_ME:
         opt = AdaptiveSolverOptions(**options)
diff --git a/pyqtorch/time_dependent/methods/dp5.py b/pyqtorch/time_dependent/methods/dp5.py
index 78c0c054..9b4a2cb8 100644
--- a/pyqtorch/time_dependent/methods/dp5.py
+++ b/pyqtorch/time_dependent/methods/dp5.py
@@ -75,10 +75,11 @@ def step(
 
         # compute iterated Runge-Kutta values
         k = torch.empty(7, *f.shape, dtype=self.options.ctype)
-        k[0] = f.to_dense()
+        k[0] = f.to_dense() if self.options.use_sparse else f
         for i in range(1, 7):
             dy = torch.tensordot(dt * beta[i - 1, :i], k[:i].clone(), dims=([0], [0]))
-            k[i] = fun(t + dt * alpha[i - 1].item(), y + dy).to_dense()
+            a = fun(t + dt * alpha[i - 1].item(), y + dy)
+            k[i] = a.to_dense() if self.options.use_sparse else a
 
         # compute results
         f_new = k[-1]
diff --git a/pyqtorch/time_dependent/options.py b/pyqtorch/time_dependent/options.py
index ae29fcad..a21b769b 100644
--- a/pyqtorch/time_dependent/options.py
+++ b/pyqtorch/time_dependent/options.py
@@ -17,6 +17,7 @@ class AdaptiveSolverOptions:
     max_factor: float = 5.0
     ctype: dtype = torch.complex128
     rtype: dtype = torch.float64
+    use_sparse: bool = False
 
 
 @dataclass
@@ -25,3 +26,4 @@ class KrylovSolverOptions:
     max_krylov: int = 80
     exp_tolerance: float = 1e-10
     norm_tolerance: float = 1e-10
+    use_sparse: bool = False
diff --git a/pyqtorch/time_dependent/sesolve.py b/pyqtorch/time_dependent/sesolve.py
index 45092b00..3c463fb7 100644
--- a/pyqtorch/time_dependent/sesolve.py
+++ b/pyqtorch/time_dependent/sesolve.py
@@ -29,6 +29,14 @@ def sesolve(
     Returns:
         Result: dataclass containing the simulated states at each time moment
     """
+    # check dimensions of initial state
+    n = H(0.0).shape[0]
+    if (psi0.shape[0] != n) or len(psi0.shape) != 2:
+        raise ValueError(
+            f"Argument `psi0` must be a 2D tensor of shape `({n}, batch_size)`. Current shape:"
+            f" {tuple(psi0.shape)}."
+        )
+
     options = options or dict()
     # instantiate appropriate solver
     if solver == SolverType.DP5_SE:
diff --git a/pyqtorch/time_dependent/solvers.py b/pyqtorch/time_dependent/solvers.py
index 480106f0..22e4826a 100644
--- a/pyqtorch/time_dependent/solvers.py
+++ b/pyqtorch/time_dependent/solvers.py
@@ -30,7 +30,9 @@ def ode_fun(self, t: float, psi: Tensor) -> Tensor:
         with warnings.catch_warnings():
             # filter-out UserWarning about "Sparse CSR tensor support is in beta state"
             warnings.filterwarnings("ignore", category=UserWarning)
-            res = -1j * self.H(t) @ psi.to_sparse()
+            res = (
+                -1j * self.H(t) @ (psi.to_sparse() if self.options.use_sparse else psi)
+            )
         return res
 
 
diff --git a/tests/test_analog.py b/tests/test_analog.py
index 64ac0d3d..533f3c79 100644
--- a/tests/test_analog.py
+++ b/tests/test_analog.py
@@ -293,11 +293,15 @@ def test_hamevo_endianness_cnot() -> None:
 
 
 @pytest.mark.parametrize("duration", [torch.rand(1), "duration"])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize("use_sparse", [True, False])
 @pytest.mark.parametrize("ode_solver", [SolverType.DP5_SE, SolverType.KRYLOV_SE])
 def test_timedependent(
     tparam: str,
     param_y: float,
     duration: float,
+    batch_size: int,
+    use_sparse: bool,
     n_steps: int,
     torch_hamiltonian: Callable,
     hamevo_generator: Sequence,
@@ -306,14 +310,18 @@ def test_timedependent(
     ode_solver: SolverType,
 ) -> None:
 
-    psi_start = random_state(2)
+    psi_start = random_state(2, batch_size)
 
     dur_val = duration if isinstance(duration, torch.Tensor) else torch.rand(1)
 
     # simulate with time-dependent solver
     t_points = torch.linspace(0, dur_val[0], n_steps)
     psi_solver = pyq.sesolve(
-        torch_hamiltonian, psi_start.reshape(-1, 1), t_points, ode_solver
+        torch_hamiltonian,
+        psi_start.reshape(-1, batch_size),
+        t_points,
+        ode_solver,
+        options={"use_sparse": use_sparse},
     ).states[-1]
 
     # simulate with HamiltonianEvolution
@@ -324,14 +332,15 @@ def test_timedependent(
     hamiltonian_evolution = pyq.HamiltonianEvolution(
         generator=hamevo_generator,
         time=tparam,
-        duration=duration,
+        duration=dur_val,
         steps=n_steps,
         solver=ode_solver,
+        use_sparse=use_sparse,
     )
     values = {"y": param_y, "duration": dur_val}
     psi_hamevo = hamiltonian_evolution(
         state=psi_start, values=values, embedding=embedding
-    ).reshape(-1, 1)
+    ).reshape(-1, batch_size)
 
     assert torch.allclose(psi_solver, psi_hamevo, rtol=RTOL, atol=ATOL)
 
diff --git a/tests/test_solvers.py b/tests/test_solvers.py
index f6bc51ef..7a17c854 100644
--- a/tests/test_solvers.py
+++ b/tests/test_solvers.py
@@ -17,19 +17,23 @@
 
 
 @pytest.mark.flaky(max_runs=5)
+@pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize("ode_solver", [SolverType.DP5_SE, SolverType.KRYLOV_SE])
 def test_sesolve(
     duration: float,
+    batch_size: int,
     n_steps: int,
     torch_hamiltonian: Callable,
     qutip_hamiltonian: Callable,
     ode_solver: SolverType,
 ) -> None:
-
     psi0_qutip = qutip.basis(4, 0)
 
     # simulate with torch-based solver
-    psi0_torch = torch.tensor(psi0_qutip.full()).to(torch.complex128)
+    psi0_torch = (
+        torch.tensor(psi0_qutip.full()).to(torch.complex128).repeat(1, batch_size)
+    )
+
     t_points = torch.linspace(0, duration, n_steps)
     state_torch = sesolve(torch_hamiltonian, psi0_torch, t_points, ode_solver).states[
         -1
@@ -38,15 +42,17 @@ def test_sesolve(
     # simulate with qutip solver
     t_points = np.linspace(0, duration, n_steps)
     result = qutip.sesolve(qutip_hamiltonian, psi0_qutip, t_points)
-    state_qutip = torch.tensor(result.states[-1].full())
+    state_qutip = torch.tensor(result.states[-1].full()).repeat(1, batch_size)
 
     assert torch.allclose(state_torch, state_qutip, atol=ATOL)
 
 
 @pytest.mark.flaky(max_runs=5)
+@pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize("ode_solver", [SolverType.DP5_ME])
 def test_mesolve(
     duration: float,
+    batch_size: int,
     n_steps: int,
     torch_hamiltonian: Callable,
     qutip_hamiltonian: Callable,
@@ -58,14 +64,20 @@ def test_mesolve(
 
     # simulate with torch-based solver
     psi0_torch = torch.tensor(psi0_qutip.full()).to(torch.complex128)
+    rho0_torch = (
+        torch.matmul(psi0_torch, psi0_torch.T).unsqueeze(-1).repeat(1, 1, batch_size)
+    )
+
     t_points = torch.linspace(0, duration, n_steps)
     state_torch = mesolve(
-        torch_hamiltonian, psi0_torch, jump_op_torch, t_points, ode_solver
+        torch_hamiltonian, rho0_torch, jump_op_torch, t_points, ode_solver
     ).states[-1]
 
     # simulate with qutip solver
     t_points = np.linspace(0, duration, n_steps)
     result = qutip.mesolve(qutip_hamiltonian, psi0_qutip, t_points, jump_op_qutip)
-    state_qutip = torch.tensor(result.states[-1].full())
+    state_qutip = (
+        torch.tensor(result.states[-1].full()).unsqueeze(-1).repeat(1, 1, batch_size)
+    )
 
     assert torch.allclose(state_torch, state_qutip, atol=ATOL)