From 3ede78818d7dc0c062b1bd1fa4dafbad6f03c1bf Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 19 Aug 2024 14:13:04 +0200 Subject: [PATCH] add test time init + better finite_diff writing --- pyqtorch/hamiltonians/evolution.py | 7 ++++++- pyqtorch/utils.py | 25 ++++++++++++++----------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/pyqtorch/hamiltonians/evolution.py b/pyqtorch/hamiltonians/evolution.py index aad6ff37..b11f4e05 100644 --- a/pyqtorch/hamiltonians/evolution.py +++ b/pyqtorch/hamiltonians/evolution.py @@ -197,7 +197,12 @@ def __init__( ) super().__init__(generator) self._qubit_support = qubit_support # type: ignore - self.time = time + + if isinstance(time, str) or isinstance(time, Tensor): + self.time = time + else: + raise ValueError("time should be passed as str or tensor.") + logger.debug("Hamiltonian Evolution initialized") if logger.isEnabledFor(logging.DEBUG): # When Debugging let's add logging and NVTX markers diff --git a/pyqtorch/utils.py b/pyqtorch/utils.py index cdcce531..34270dc5 100644 --- a/pyqtorch/utils.py +++ b/pyqtorch/utils.py @@ -247,29 +247,32 @@ def finitediff( eps = torch.finfo(x.dtype).eps ** (1 / (2 + order)) # compute derivative direction vector(s) - eps = torch.as_tensor(eps, dtype=x.dtype) - _eps = 1 / eps # type: ignore[operator] - ev = torch.zeros_like(x) + delta = torch.zeros_like(x) i = derivative_indices[0] - ev[:, i] += eps + delta[:, i] += torch.as_tensor(eps, dtype=x.dtype) + denominator = 1 / delta[:, i] # recursive finite differencing for higher order than 3 / mixed derivatives if len(derivative_indices) > 3 or len(set(derivative_indices)) > 1: di = derivative_indices[1:] - return (finitediff(f, x + ev, di) - finitediff(f, x - ev, di)) * _eps / 2 + return ( + (finitediff(f, x + delta, di) - finitediff(f, x - delta, di)) + * denominator + / 2 + ) if len(derivative_indices) == 3: return ( - (f(x + 2 * ev) - 2 * f(x + ev) + 2 * f(x - ev) - f(x - 2 * ev)) - * _eps**3 + (f(x + 2 * delta) - 2 * f(x + delta) + 2 * f(x - delta) - f(x - 2 * delta)) + * denominator**3 / 2 ) if len(derivative_indices) == 2: - return (f(x + ev) + f(x - ev) - 2 * f(x)) * _eps**2 + return (f(x + delta) + f(x - delta) - 2 * f(x)) * denominator**2 if len(derivative_indices) == 1: - return (f(x + ev) - f(x - ev)) * _eps / 2 + return (f(x + delta) - f(x - delta)) * denominator / 2 raise ValueError( - "If you see this error there is a bug in the `finitediff` function." - ) + "If you see this error there is a bug in the `finitediff` function." + ) def product_state(